feat: add login func

main
zggsong 2 years ago
parent 8f82600347
commit 761ad266f4

@ -6,6 +6,8 @@ import (
"os" "os"
) )
var Config model.Config
func ConfigInit() (model.Config, error) { func ConfigInit() (model.Config, error) {
var config model.Config var config model.Config
@ -27,3 +29,7 @@ func ConfigInit() (model.Config, error) {
return config, nil return config, nil
} }
func GetConfig() model.Config {
return Config
}

@ -8,6 +8,8 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var DB *gorm.DB
func InitDB(conf model.DataSource) (*gorm.DB, error) { func InitDB(conf model.DataSource) (*gorm.DB, error) {
// 拼接下 dsn 参数, dsn 格式可以参考上面的语法 // 拼接下 dsn 参数, dsn 格式可以参考上面的语法
//这里使用 Sprintf 动态拼接 dsn 参数,因为一般数据库连接参数,我们都是保存在配置文件里面,需要从配置文件加载参数,然后拼接 dsn。 //这里使用 Sprintf 动态拼接 dsn 参数,因为一般数据库连接参数,我们都是保存在配置文件里面,需要从配置文件加载参数,然后拼接 dsn。
@ -29,5 +31,10 @@ func InitDB(conf model.DataSource) (*gorm.DB, error) {
return nil, err return nil, err
} }
DB = db
return db, nil return db, nil
} }
func GetDB() *gorm.DB {
return DB
}

@ -0,0 +1,44 @@
package common
import (
"expenses/model"
"github.com/dgrijalva/jwt-go"
"time"
)
var jwtKey = []byte("a_secret_key")
type Claims struct {
UserId uint
jwt.StandardClaims
}
func ReleaseToken(user model.User) (string, error) {
// 设置Token过期时间
expirationTime := time.Now().Add(7 * 24 * time.Hour)
claims := Claims{
UserId: user.ID,
StandardClaims: jwt.StandardClaims{
ExpiresAt: expirationTime.Unix(),
IssuedAt: time.Now().Unix(),
Issuer: "github.com/zggsong/expenses",
Subject: "user token",
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtKey)
if err != nil {
return "", err
}
return tokenString, nil
}
func ParseToken(tokenString string) (*jwt.Token, *Claims, error) {
claims := &Claims{}
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
return token, claims, err
}

@ -0,0 +1,23 @@
package common
import (
"net/http"
"github.com/gin-gonic/gin"
)
func Response(ctx *gin.Context, httpStatus int, code int, data gin.H, msg string) {
ctx.JSON(httpStatus, gin.H{
"code": code,
"data": data,
"msg": msg,
})
}
func Success(ctx *gin.Context, data gin.H, msg string) {
Response(ctx, http.StatusOK, 200, data, msg)
}
func Fail(ctx *gin.Context, data gin.H, msg string) {
Response(ctx, http.StatusOK, 400, data, msg)
}

@ -0,0 +1,13 @@
package controller
import (
"expenses/common"
"expenses/dto"
"expenses/model"
"github.com/gin-gonic/gin"
)
func Info(ctx *gin.Context) {
user, _ := ctx.Get("user")
common.Success(ctx, gin.H{"user": dto.ToUserDto(user.(model.User))}, "")
}

@ -0,0 +1,50 @@
package controller
import (
"expenses/common"
"expenses/model"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
"net/http"
)
func Login(ctx *gin.Context) {
DB := common.GetDB()
// 获取参数
telephone := ctx.Query("telephone")
password := ctx.Query("password")
// 数据验证
if len(telephone) != 11 {
common.Fail(ctx, nil, "手机号必须是11位")
return
}
if len(password) < 6 {
common.Fail(ctx, nil, "密码必须大于6位")
return
}
// 判断手机号是否存在
var user model.User
DB.Where("telephone = ?", telephone).First(&user)
if user.ID == 0 {
common.Fail(ctx, nil, "用户不存在")
return
}
// 密码验证
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
common.Fail(ctx, nil, "密码错误")
return
}
// 发放Token
token, err := common.ReleaseToken(user)
if err != nil {
common.Response(ctx, http.StatusInternalServerError, 500, nil, "发放Token失败")
return
}
// 返回结果
common.Success(ctx, gin.H{"token": token}, "登录成功")
}

@ -0,0 +1,61 @@
package controller
import (
"expenses/common"
"expenses/model"
"expenses/util"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"net/http"
)
func Register(ctx *gin.Context) {
DB := common.GetDB()
// 获取参数
name := ctx.Query("name")
telephone := ctx.Query("telephone")
password := ctx.Query("password")
// 数据验证
if len(telephone) != 11 {
common.Fail(ctx, nil, "手机号必须是11位")
return
}
if len(password) < 6 {
common.Fail(ctx, nil, "密码必须大于6位")
return
}
// 如果名字为空则返回10为随机字符串
if name == "" {
name = util.RandomString(10)
}
// 判断手机号是否存在
if isExistTelephone(DB, telephone) {
common.Fail(ctx, nil, "用户已存在")
return
}
// 创建用户
hasedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
common.Response(ctx, http.StatusInternalServerError, 500, nil, "密码加密失败")
return
}
newUser := model.User{
Name: name,
Telephone: telephone,
Password: string(hasedPassword),
}
DB.Create(&newUser)
// 返回结果
common.Success(ctx, nil, "注册成功")
}
func isExistTelephone(db *gorm.DB, telephone string) bool {
var user model.User
db.Where("telephone = ?", telephone).First(&user)
return user.ID != 0
}

@ -0,0 +1,15 @@
package dto
import "expenses/model"
type UserDto struct {
Name string `json:"name"`
Telephone string `json:"telephone"`
}
func ToUserDto(user model.User) UserDto {
return UserDto{
Name: user.Name,
Telephone: user.Telephone,
}
}

@ -1,11 +0,0 @@
package global
import (
"expenses/model"
"gorm.io/gorm"
)
var (
GLO_CONF model.Config
GLO_DB *gorm.DB
)

@ -3,8 +3,10 @@ module expenses
go 1.19 go 1.19
require ( require (
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/gin-gonic/gin v1.8.2 github.com/gin-gonic/gin v1.8.2
github.com/spf13/viper v1.15.0 github.com/spf13/viper v1.15.0
golang.org/x/crypto v0.4.0
gorm.io/driver/postgres v1.4.7 gorm.io/driver/postgres v1.4.7
gorm.io/gorm v1.24.5 gorm.io/gorm v1.24.5
) )
@ -36,7 +38,6 @@ require (
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.4.2 // indirect github.com/subosito/gotenv v1.4.2 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect github.com/ugorji/go/codec v1.2.7 // indirect
golang.org/x/crypto v0.4.0 // indirect
golang.org/x/net v0.4.0 // indirect golang.org/x/net v0.4.0 // indirect
golang.org/x/sys v0.3.0 // indirect golang.org/x/sys v0.3.0 // indirect
golang.org/x/text v0.5.0 // indirect golang.org/x/text v0.5.0 // indirect

@ -50,6 +50,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=

@ -2,7 +2,6 @@ package main
import ( import (
"expenses/common" "expenses/common"
"expenses/global"
"expenses/router" "expenses/router"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"os" "os"
@ -20,21 +19,17 @@ func main() {
} }
/*数据库初始化*/ /*数据库初始化*/
initDB, err := common.InitDB(conf.DataSource) _, err = common.InitDB(conf.DataSource)
if err != nil { if err != nil {
common.Log.Printf("数据库初始化失败!") common.Log.Printf("数据库初始化失败!")
os.Exit(-2) os.Exit(-2)
} }
/*全局变量初始化*/
global.GLO_CONF = conf
global.GLO_DB = initDB
/*Gin初始化*/ /*Gin初始化*/
//gin.SetMode(gin.ReleaseMode) //gin.SetMode(gin.ReleaseMode)
r := gin.Default() r := gin.Default()
r = router.CollectRoute(r) r = router.CollectRoute(r)
port := global.GLO_CONF.Server.Port port := conf.Server.Port
if port == "" { if port == "" {
port = "8080" port = "8080"
} }

@ -0,0 +1,63 @@
package middleware
import (
"expenses/common"
"expenses/model"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// AuthMiddleWare
//
// @Description: 认证中间件
// @return gin.HandlerFunc
func AuthMiddleWare() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取 authorization header
tokenString := c.GetHeader("Authorization")
// 验证token
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
c.JSON(http.StatusOK, gin.H{
"code": http.StatusUnauthorized,
"message": "请求未授权",
})
c.Abort()
return
}
tokenString = tokenString[7:]
token, claims, err := common.ParseToken(tokenString)
if err != nil || !token.Valid {
c.JSON(http.StatusOK, gin.H{
"code": http.StatusUnauthorized,
"message": "请求未授权",
})
c.Abort()
return
}
// 验证通过后获取Claim中的UserId
userId := claims.UserId
DB := common.GetDB()
var user model.User
DB.First(&user, userId)
// 用户不存在
if user.ID == 0 {
c.JSON(http.StatusOK, gin.H{
"code": http.StatusUnauthorized,
"message": "请求未授权",
})
c.Abort()
return
}
// 用户信息存在,则将用户信息存入上下文
c.Set("user", user)
c.Next()
}
}

@ -1,10 +1,14 @@
package router package router
import "github.com/gin-gonic/gin" import (
"expenses/controller"
"expenses/middleware"
"github.com/gin-gonic/gin"
)
func CollectRoute(r *gin.Engine) *gin.Engine { func CollectRoute(r *gin.Engine) *gin.Engine {
//r.POST("/auth/register", controller.Register) r.POST("/auth/register", controller.Register)
//r.POST("/auth/login", controller.Login) r.POST("/auth/login", controller.Login)
//r.GET("/auth/info", middleware.AuthMiddleWare(), controller.Info) // 认证中间件保护info接口 r.GET("/auth/info", middleware.AuthMiddleWare(), controller.Info) // 认证中间件保护info接口
return r return r
} }

@ -0,0 +1,22 @@
package util
import (
"math/rand"
"time"
)
// RandomString
//
// @Description: 随机字符串
// @param length
// @return string
func RandomString(length int) string {
str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
bytes := []byte(str)
var result []byte
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < length; i++ {
result = append(result, bytes[r.Intn(len(bytes))])
}
return string(result)
}
Loading…
Cancel
Save