From 761ad266f4a01e1b4978610ba9989738b5c65ca3 Mon Sep 17 00:00:00 2001 From: zggsong Date: Thu, 9 Feb 2023 16:42:21 +0800 Subject: [PATCH] feat: add login func --- server/common/config.go | 6 +++ server/common/db.go | 7 ++++ server/common/jwt.go | 44 ++++++++++++++++++++ server/common/response.go | 23 +++++++++++ server/controller/info.go | 13 ++++++ server/controller/login.go | 50 +++++++++++++++++++++++ server/controller/register.go | 61 ++++++++++++++++++++++++++++ server/dto/user_dto.go | 15 +++++++ server/global/global.go | 11 ----- server/go.mod | 3 +- server/go.sum | 2 + server/main.go | 9 +---- server/middleware/authmiddleware.go | 63 +++++++++++++++++++++++++++++ server/router/router.go | 12 ++++-- server/util/util.go | 22 ++++++++++ 15 files changed, 318 insertions(+), 23 deletions(-) create mode 100644 server/common/jwt.go create mode 100644 server/common/response.go create mode 100644 server/controller/info.go create mode 100644 server/controller/login.go create mode 100644 server/controller/register.go create mode 100644 server/dto/user_dto.go delete mode 100644 server/global/global.go create mode 100644 server/middleware/authmiddleware.go create mode 100644 server/util/util.go diff --git a/server/common/config.go b/server/common/config.go index 63a3541..bd3edd7 100644 --- a/server/common/config.go +++ b/server/common/config.go @@ -6,6 +6,8 @@ import ( "os" ) +var Config model.Config + func ConfigInit() (model.Config, error) { var config model.Config @@ -27,3 +29,7 @@ func ConfigInit() (model.Config, error) { return config, nil } + +func GetConfig() model.Config { + return Config +} diff --git a/server/common/db.go b/server/common/db.go index bc0e3ff..4cb1f38 100644 --- a/server/common/db.go +++ b/server/common/db.go @@ -8,6 +8,8 @@ import ( "gorm.io/gorm" ) +var DB *gorm.DB + func InitDB(conf model.DataSource) (*gorm.DB, error) { // 拼接下 dsn 参数, dsn 格式可以参考上面的语法 //这里使用 Sprintf 动态拼接 dsn 参数,因为一般数据库连接参数,我们都是保存在配置文件里面,需要从配置文件加载参数,然后拼接 dsn。 @@ -29,5 +31,10 @@ func InitDB(conf model.DataSource) (*gorm.DB, error) { return nil, err } + DB = db return db, nil } + +func GetDB() *gorm.DB { + return DB +} diff --git a/server/common/jwt.go b/server/common/jwt.go new file mode 100644 index 0000000..a13e135 --- /dev/null +++ b/server/common/jwt.go @@ -0,0 +1,44 @@ +package common + +import ( + "expenses/model" + "github.com/dgrijalva/jwt-go" + "time" +) + +var jwtKey = []byte("a_secret_key") + +type Claims struct { + UserId uint + jwt.StandardClaims +} + +func ReleaseToken(user model.User) (string, error) { + // 设置Token过期时间 + expirationTime := time.Now().Add(7 * 24 * time.Hour) + claims := Claims{ + UserId: user.ID, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: expirationTime.Unix(), + IssuedAt: time.Now().Unix(), + Issuer: "github.com/zggsong/expenses", + Subject: "user token", + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(jwtKey) + if err != nil { + return "", err + } + + return tokenString, nil +} + +func ParseToken(tokenString string) (*jwt.Token, *Claims, error) { + claims := &Claims{} + + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + return jwtKey, nil + }) + return token, claims, err +} diff --git a/server/common/response.go b/server/common/response.go new file mode 100644 index 0000000..5f04945 --- /dev/null +++ b/server/common/response.go @@ -0,0 +1,23 @@ +package common + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func Response(ctx *gin.Context, httpStatus int, code int, data gin.H, msg string) { + ctx.JSON(httpStatus, gin.H{ + "code": code, + "data": data, + "msg": msg, + }) +} + +func Success(ctx *gin.Context, data gin.H, msg string) { + Response(ctx, http.StatusOK, 200, data, msg) +} + +func Fail(ctx *gin.Context, data gin.H, msg string) { + Response(ctx, http.StatusOK, 400, data, msg) +} diff --git a/server/controller/info.go b/server/controller/info.go new file mode 100644 index 0000000..a0d20f0 --- /dev/null +++ b/server/controller/info.go @@ -0,0 +1,13 @@ +package controller + +import ( + "expenses/common" + "expenses/dto" + "expenses/model" + "github.com/gin-gonic/gin" +) + +func Info(ctx *gin.Context) { + user, _ := ctx.Get("user") + common.Success(ctx, gin.H{"user": dto.ToUserDto(user.(model.User))}, "") +} diff --git a/server/controller/login.go b/server/controller/login.go new file mode 100644 index 0000000..2dd5caa --- /dev/null +++ b/server/controller/login.go @@ -0,0 +1,50 @@ +package controller + +import ( + "expenses/common" + "expenses/model" + "github.com/gin-gonic/gin" + "golang.org/x/crypto/bcrypt" + "net/http" +) + +func Login(ctx *gin.Context) { + DB := common.GetDB() + // 获取参数 + telephone := ctx.Query("telephone") + password := ctx.Query("password") + + // 数据验证 + if len(telephone) != 11 { + common.Fail(ctx, nil, "手机号必须是11位") + return + } + if len(password) < 6 { + common.Fail(ctx, nil, "密码必须大于6位") + return + } + + // 判断手机号是否存在 + var user model.User + DB.Where("telephone = ?", telephone).First(&user) + if user.ID == 0 { + common.Fail(ctx, nil, "用户不存在") + return + } + + // 密码验证 + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { + common.Fail(ctx, nil, "密码错误") + return + } + + // 发放Token + token, err := common.ReleaseToken(user) + if err != nil { + common.Response(ctx, http.StatusInternalServerError, 500, nil, "发放Token失败") + return + } + + // 返回结果 + common.Success(ctx, gin.H{"token": token}, "登录成功") +} diff --git a/server/controller/register.go b/server/controller/register.go new file mode 100644 index 0000000..1a86b6a --- /dev/null +++ b/server/controller/register.go @@ -0,0 +1,61 @@ +package controller + +import ( + "expenses/common" + "expenses/model" + "expenses/util" + "github.com/gin-gonic/gin" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" + "net/http" +) + +func Register(ctx *gin.Context) { + DB := common.GetDB() + // 获取参数 + name := ctx.Query("name") + telephone := ctx.Query("telephone") + password := ctx.Query("password") + + // 数据验证 + if len(telephone) != 11 { + common.Fail(ctx, nil, "手机号必须是11位") + return + } + if len(password) < 6 { + common.Fail(ctx, nil, "密码必须大于6位") + return + } + // 如果名字为空则返回10为随机字符串 + if name == "" { + name = util.RandomString(10) + } + + // 判断手机号是否存在 + if isExistTelephone(DB, telephone) { + common.Fail(ctx, nil, "用户已存在") + return + } + + // 创建用户 + hasedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + common.Response(ctx, http.StatusInternalServerError, 500, nil, "密码加密失败") + return + } + newUser := model.User{ + Name: name, + Telephone: telephone, + Password: string(hasedPassword), + } + DB.Create(&newUser) + + // 返回结果 + common.Success(ctx, nil, "注册成功") +} + +func isExistTelephone(db *gorm.DB, telephone string) bool { + var user model.User + db.Where("telephone = ?", telephone).First(&user) + return user.ID != 0 +} diff --git a/server/dto/user_dto.go b/server/dto/user_dto.go new file mode 100644 index 0000000..85e2528 --- /dev/null +++ b/server/dto/user_dto.go @@ -0,0 +1,15 @@ +package dto + +import "expenses/model" + +type UserDto struct { + Name string `json:"name"` + Telephone string `json:"telephone"` +} + +func ToUserDto(user model.User) UserDto { + return UserDto{ + Name: user.Name, + Telephone: user.Telephone, + } +} diff --git a/server/global/global.go b/server/global/global.go deleted file mode 100644 index c0add84..0000000 --- a/server/global/global.go +++ /dev/null @@ -1,11 +0,0 @@ -package global - -import ( - "expenses/model" - "gorm.io/gorm" -) - -var ( - GLO_CONF model.Config - GLO_DB *gorm.DB -) diff --git a/server/go.mod b/server/go.mod index d1aaca7..98ae3a7 100644 --- a/server/go.mod +++ b/server/go.mod @@ -3,8 +3,10 @@ module expenses go 1.19 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/gin-gonic/gin v1.8.2 github.com/spf13/viper v1.15.0 + golang.org/x/crypto v0.4.0 gorm.io/driver/postgres v1.4.7 gorm.io/gorm v1.24.5 ) @@ -36,7 +38,6 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.4.2 // indirect github.com/ugorji/go/codec v1.2.7 // indirect - golang.org/x/crypto v0.4.0 // indirect golang.org/x/net v0.4.0 // indirect golang.org/x/sys v0.3.0 // indirect golang.org/x/text v0.5.0 // indirect diff --git a/server/go.sum b/server/go.sum index 9f78949..0a3931d 100644 --- a/server/go.sum +++ b/server/go.sum @@ -50,6 +50,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= diff --git a/server/main.go b/server/main.go index aa50972..5a9dd64 100644 --- a/server/main.go +++ b/server/main.go @@ -2,7 +2,6 @@ package main import ( "expenses/common" - "expenses/global" "expenses/router" "github.com/gin-gonic/gin" "os" @@ -20,21 +19,17 @@ func main() { } /*数据库初始化*/ - initDB, err := common.InitDB(conf.DataSource) + _, err = common.InitDB(conf.DataSource) if err != nil { common.Log.Printf("数据库初始化失败!") os.Exit(-2) } - /*全局变量初始化*/ - global.GLO_CONF = conf - global.GLO_DB = initDB - /*Gin初始化*/ //gin.SetMode(gin.ReleaseMode) r := gin.Default() r = router.CollectRoute(r) - port := global.GLO_CONF.Server.Port + port := conf.Server.Port if port == "" { port = "8080" } diff --git a/server/middleware/authmiddleware.go b/server/middleware/authmiddleware.go new file mode 100644 index 0000000..5a2430f --- /dev/null +++ b/server/middleware/authmiddleware.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "expenses/common" + "expenses/model" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +// AuthMiddleWare +// +// @Description: 认证中间件 +// @return gin.HandlerFunc +func AuthMiddleWare() gin.HandlerFunc { + return func(c *gin.Context) { + // 获取 authorization header + tokenString := c.GetHeader("Authorization") + + // 验证token + if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") { + c.JSON(http.StatusOK, gin.H{ + "code": http.StatusUnauthorized, + "message": "请求未授权", + }) + c.Abort() + return + } + + tokenString = tokenString[7:] + + token, claims, err := common.ParseToken(tokenString) + if err != nil || !token.Valid { + c.JSON(http.StatusOK, gin.H{ + "code": http.StatusUnauthorized, + "message": "请求未授权", + }) + c.Abort() + return + } + // 验证通过后获取Claim中的UserId + userId := claims.UserId + DB := common.GetDB() + var user model.User + DB.First(&user, userId) + + // 用户不存在 + if user.ID == 0 { + c.JSON(http.StatusOK, gin.H{ + "code": http.StatusUnauthorized, + "message": "请求未授权", + }) + c.Abort() + return + } + + // 用户信息存在,则将用户信息存入上下文 + c.Set("user", user) + + c.Next() + } +} diff --git a/server/router/router.go b/server/router/router.go index 6fb1aba..605275e 100644 --- a/server/router/router.go +++ b/server/router/router.go @@ -1,10 +1,14 @@ package router -import "github.com/gin-gonic/gin" +import ( + "expenses/controller" + "expenses/middleware" + "github.com/gin-gonic/gin" +) func CollectRoute(r *gin.Engine) *gin.Engine { - //r.POST("/auth/register", controller.Register) - //r.POST("/auth/login", controller.Login) - //r.GET("/auth/info", middleware.AuthMiddleWare(), controller.Info) // 认证中间件保护info接口 + r.POST("/auth/register", controller.Register) + r.POST("/auth/login", controller.Login) + r.GET("/auth/info", middleware.AuthMiddleWare(), controller.Info) // 认证中间件保护info接口 return r } diff --git a/server/util/util.go b/server/util/util.go new file mode 100644 index 0000000..14c7da1 --- /dev/null +++ b/server/util/util.go @@ -0,0 +1,22 @@ +package util + +import ( + "math/rand" + "time" +) + +// RandomString +// +// @Description: 随机字符串 +// @param length +// @return string +func RandomString(length int) string { + str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + bytes := []byte(str) + var result []byte + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < length; i++ { + result = append(result, bytes[r.Intn(len(bytes))]) + } + return string(result) +}