leveldb存储token的简单实现
在之前的文章中,leveldb 的 gRPC 接口还没有实现 TTL(过期时间),这期补上。
一、改造leveldb的grpc接口支持ttl
这里简单改造:在 value 后面多存一个 8 字节的过期时间戳(Unix 纳秒,大端序),查询时将它与当前时间比较即可。
leveldb.proto
syntax = "proto3";
option go_package = "../src/leveldb;leveldb";
package leveldb;
// LevelDBService is a small key-value API; the Go server below stores keys in LevelDB.
service LevelDBService {
// Put stores value under key; PutRequest.ttl controls expiration.
rpc Put(PutRequest) returns (PutResponse) {}
// Get returns the value of key; the server answers with an empty value when the key is missing or expired.
rpc Get(GetRequest) returns (GetResponse) {}
// Has reports whether key exists and has not expired.
rpc Has(HasRequest) returns (HasResponse) {}
// Delete removes key.
rpc Delete(DeleteRequest) returns (DeleteResponse) {}
rpc GetTTL(GetTTLRequest) returns (GetTTLResponse) {} // newly added: remaining time to live of a key
}
message PutRequest {
string key = 1;
string value = 2;
int64 ttl = 3; // time to live in seconds; the server applies it only when > 0
}
message PutResponse {}
message GetRequest {
string key = 1;
}
message GetResponse {
string value = 1;
}
message HasRequest {
string key = 1;
}
message HasResponse {
bool exists = 1;
}
message DeleteRequest {
string key = 1;
}
message DeleteResponse {}
message GetTTLRequest {
string key = 1;
}
message GetTTLResponse {
int64 ttl = 1; // remaining seconds; -1 if the key is missing or has no TTL, 0 if it has expired
}
main.go
package main
import (
"context"
"encoding/binary"
"log"
"net"
"time"
pb "gin/src/leveldb" // 替换为实际路径
leveldb "github.com/syndtr/goleveldb/leveldb"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
// server implements the LevelDBService gRPC API on top of a LevelDB database.
// It embeds UnimplementedLevelDBServiceServer from the generated code
// (presumably the standard protoc-gen-go-grpc forward-compatibility stub for
// RPCs this type does not implement — not shown here).
type server struct {
pb.UnimplementedLevelDBServiceServer
db *leveldb.DB // opened in main and shared by all RPC handlers
}
// Stored value layout: the caller's payload followed by an 8-byte big-endian
// trailer holding the absolute expiration time in Unix nanoseconds, where 0
// means "no TTL". The trailer is written for EVERY key so reads can always
// strip it unambiguously. Previously it was appended only when a TTL was set
// but stripped from any value longer than 8 bytes, so TTL-less payloads lost
// their last 8 bytes (and could even read as expired), and an empty payload
// stored with a TTL never expired.
//
// Compatibility: values written earlier WITH a TTL already have this layout
// and decode unchanged (tokens are written with a positive TTL by default);
// values written earlier WITHOUT a TTL are not reliably readable.
const (
	// expirationTrailerLen is the size in bytes of the expiration trailer.
	expirationTrailerLen = 8
	// noExpiration is the trailer value of keys stored without a TTL.
	noExpiration int64 = 0
	// maxInt64 bounds the deadline arithmetic in expirationFor.
	maxInt64 = 1<<63 - 1
)

// expirationFor converts a TTL in seconds into an absolute Unix-nanosecond
// deadline. Non-positive TTLs, and TTLs so large that the deadline would
// overflow int64, yield noExpiration.
func expirationFor(ttlSeconds int64) int64 {
	if ttlSeconds <= 0 {
		return noExpiration
	}
	now := time.Now().UnixNano()
	if ttlSeconds > (maxInt64-now)/int64(time.Second) {
		return noExpiration
	}
	return now + ttlSeconds*int64(time.Second)
}

// encodeValue returns payload followed by the big-endian expiration trailer.
func encodeValue(payload []byte, expiresAt int64) []byte {
	out := make([]byte, len(payload)+expirationTrailerLen)
	n := copy(out, payload)
	binary.BigEndian.PutUint64(out[n:], uint64(expiresAt))
	return out
}

// decodeValue splits a stored value into payload and expiration. A value
// shorter than the trailer cannot have come from encodeValue; it is returned
// unchanged as a never-expiring payload.
func decodeValue(stored []byte) (payload []byte, expiresAt int64) {
	if len(stored) < expirationTrailerLen {
		return stored, noExpiration
	}
	cut := len(stored) - expirationTrailerLen
	return stored[:cut], int64(binary.BigEndian.Uint64(stored[cut:]))
}

// isExpired reports whether a key with deadline expiresAt is expired at now
// (both Unix nanoseconds). Keys without a TTL never expire.
func isExpired(expiresAt, now int64) bool {
	return expiresAt != noExpiration && now >= expiresAt
}

// lookup reads and decodes key. found is false (with a nil error) when the
// key does not exist; other database errors are returned as-is.
func (s *server) lookup(key string) (payload []byte, expiresAt int64, found bool, err error) {
	stored, err := s.db.Get([]byte(key), nil)
	if err == leveldb.ErrNotFound {
		return nil, noExpiration, false, nil
	}
	if err != nil {
		return nil, noExpiration, false, err
	}
	payload, expiresAt = decodeValue(stored)
	return payload, expiresAt, true, nil
}

// Put stores req.Value under req.Key. A positive req.Ttl (seconds) makes the
// key expire that long from now; otherwise it never expires.
func (s *server) Put(ctx context.Context, req *pb.PutRequest) (*pb.PutResponse, error) {
	stored := encodeValue([]byte(req.GetValue()), expirationFor(req.GetTtl()))
	if err := s.db.Put([]byte(req.GetKey()), stored, nil); err != nil {
		return nil, err
	}
	return &pb.PutResponse{}, nil
}

// Get returns the value of req.Key, or an empty value when the key is missing
// or expired. Expired keys are not deleted here.
func (s *server) Get(ctx context.Context, req *pb.GetRequest) (*pb.GetResponse, error) {
	payload, expiresAt, found, err := s.lookup(req.GetKey())
	if err != nil {
		return nil, err
	}
	if !found || isExpired(expiresAt, time.Now().UnixNano()) {
		return &pb.GetResponse{Value: ""}, nil
	}
	return &pb.GetResponse{Value: string(payload)}, nil
}

// Has reports whether req.Key exists and has not expired.
func (s *server) Has(ctx context.Context, req *pb.HasRequest) (*pb.HasResponse, error) {
	_, expiresAt, found, err := s.lookup(req.GetKey())
	if err != nil {
		return nil, err
	}
	exists := found && !isExpired(expiresAt, time.Now().UnixNano())
	return &pb.HasResponse{Exists: exists}, nil
}

// Delete removes req.Key from the database.
func (s *server) Delete(ctx context.Context, req *pb.DeleteRequest) (*pb.DeleteResponse, error) {
	if err := s.db.Delete([]byte(req.GetKey()), nil); err != nil {
		return nil, err
	}
	return &pb.DeleteResponse{}, nil
}

// GetTTL returns the remaining lifetime of req.Key in whole seconds:
// -1 when the key is missing or has no TTL, 0 when it has expired, otherwise
// the remaining time rounded up (so a still-readable key never reports 0).
func (s *server) GetTTL(ctx context.Context, req *pb.GetTTLRequest) (*pb.GetTTLResponse, error) {
	_, expiresAt, found, err := s.lookup(req.GetKey())
	if err != nil {
		return nil, err
	}
	if !found || expiresAt == noExpiration {
		return &pb.GetTTLResponse{Ttl: -1}, nil
	}
	now := time.Now().UnixNano()
	if isExpired(expiresAt, now) {
		return &pb.GetTTLResponse{Ttl: 0}, nil
	}
	remaining := expiresAt - now
	secs := remaining / int64(time.Second)
	if remaining%int64(time.Second) != 0 {
		secs++
	}
	return &pb.GetTTLResponse{Ttl: secs}, nil
}
// main opens the LevelDB database and serves the LevelDBService gRPC API
// (with server reflection enabled) on listenAddr until serving stops.
func main() {
	const (
		dbPath     = "./data/leveldb/"
		listenAddr = ":3051"
	)

	// Open the LevelDB database, creating it if it does not exist yet.
	db, err := leveldb.OpenFile(dbPath, nil)
	if err != nil {
		log.Fatalf("Failed to open database: %v", err)
	}
	// log.Fatalf terminates via os.Exit, which skips deferred calls, so a
	// `defer db.Close()` would never run on the failure paths below. Close
	// the database explicitly on every path instead.
	closeDB := func() {
		if err := db.Close(); err != nil {
			log.Printf("failed to close database: %v", err)
		}
	}

	lis, err := net.Listen("tcp", listenAddr)
	if err != nil {
		closeDB()
		log.Fatalf("failed to listen: %v", err)
	}

	s := grpc.NewServer()
	pb.RegisterLevelDBServiceServer(s, &server{db: db})
	// Server reflection lets generic gRPC tools discover the service.
	reflection.Register(s)
	log.Println("gRPC server started on " + listenAddr)

	serveErr := s.Serve(lis)
	closeDB()
	if serveErr != nil {
		log.Fatalf("failed to serve: %v", serveErr)
	}
}
其他代码不变。
二、实现token认证工具类
leveldbTokenOperator.go
package middleware
import "time"
import (
"context"
"encoding/json"
"gin-epg/src/leveldb" // 假设 gRPC 服务的 proto 文件已经编译并生成了 Go 代码
"github.com/google/uuid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
	// TokenPrefixKey is prepended to every token to form its LevelDB key.
	TokenPrefixKey = "token:"
	// UserPrefixKey is a key prefix for user entries (not referenced in this file).
	UserPrefixKey = "user:"
	// SingleUserMaxTokenSize is presumably a per-user cap on live tokens
	// (not referenced in this file).
	SingleUserMaxTokenSize = 10
	// defaultExpirationTime is the default token TTL in seconds: two weeks.
	defaultExpirationTime = 14 * 24 * 60 * 60
)
// LeveldbSimpleTokenOperator issues and looks up tokens through the
// LevelDBService gRPC client; token data lives under TokenPrefixKey+token.
type LeveldbSimpleTokenOperator struct {
expirationTimeInSecond int64 // TTL in seconds sent with each Put by generateToken
client leveldb.LevelDBServiceClient
}
func NewLeveldbSimpleTokenOperator(addr string, expirationTimeInSecond ...int64) (*LeveldbSimpleTokenOperator, error) {
conn, err := grpc.Dial(addr, grpc.WithInsecure())
if err != nil {
return nil, err
}
expiration := defaultExpirationTime
if len(expirationTimeInSecond) > 0 {
expiration = int(expirationTimeInSecond[0])
}
client := leveldb.NewLevelDBServiceClient(conn)
return &LeveldbSimpleTokenOperator{
expirationTimeInSecond: int64(expiration),
client: client,
}, nil
}
// getUserMapFromToken returns the user map stored as JSON under
// TokenPrefixKey+token.
//
// The LevelDB server in this project answers Get for a missing or expired key
// with an empty value and no error. An empty value is therefore reported as a
// gRPC NotFound status error, as is a stored JSON `null`. Returning a nil map
// with a nil error would let callers that only check err treat an unknown
// token as valid. RPC errors (including a server-side NotFound) and JSON
// decoding errors are returned unchanged.
func (r *LeveldbSimpleTokenOperator) getUserMapFromToken(token string) (map[string]interface{}, error) {
	key := TokenPrefixKey + token
	resp, err := r.client.Get(context.Background(), &leveldb.GetRequest{Key: key})
	if err != nil {
		return nil, err
	}
	if resp.GetValue() == "" {
		return nil, status.Error(codes.NotFound, "token not found or expired")
	}
	var userMap map[string]interface{}
	if err := json.Unmarshal([]byte(resp.GetValue()), &userMap); err != nil {
		return nil, err
	}
	if userMap == nil {
		return nil, status.Error(codes.NotFound, "token has no user data")
	}
	return userMap, nil
}
// getExpirationDateFromToken returns the instant at which token expires.
//
// It queries TokenPrefixKey+token, the same key generateToken writes. The
// previous version queried the bare token and so never found it. The
// LevelDB server reports -1 for a key that is missing or has no TTL; that is
// returned as a gRPC NotFound status error. An already-expired key (TTL 0)
// yields approximately the current time.
func (r *LeveldbSimpleTokenOperator) getExpirationDateFromToken(token string) (time.Time, error) {
	key := TokenPrefixKey + token
	resp, err := r.client.GetTTL(context.Background(), &leveldb.GetTTLRequest{Key: key})
	if err != nil {
		return time.Time{}, err
	}
	ttl := resp.GetTtl()
	if ttl < 0 {
		return time.Time{}, status.Error(codes.NotFound, "token not found or has no expiration")
	}
	return time.Now().Add(time.Duration(ttl) * time.Second), nil
}
// isTokenExpired reports whether TokenPrefixKey+token is no longer present
// according to the server's Has RPC. The LevelDB server in this project
// answers Has with false for both missing and expired keys, so both count as
// expired here.
func (r *LeveldbSimpleTokenOperator) isTokenExpired(token string) (bool, error) {
	req := &leveldb.HasRequest{Key: TokenPrefixKey + token}
	resp, err := r.client.Has(context.Background(), req)
	if err == nil {
		return !resp.Exists, nil
	}
	return false, err
}
// generateToken creates a random UUID token, stores userMap as JSON under
// TokenPrefixKey+token with the operator's TTL, and returns the token.
func (r *LeveldbSimpleTokenOperator) generateToken(userMap map[string]interface{}) (string, error) {
	token := uuid.New().String()
	payload, err := json.Marshal(userMap)
	if err != nil {
		return "", err
	}
	req := &leveldb.PutRequest{
		Key:   TokenPrefixKey + token,
		Value: string(payload),
		Ttl:   r.expirationTimeInSecond,
	}
	if _, err := r.client.Put(context.Background(), req); err != nil {
		return "", err
	}
	return token, nil
}
// validateToken reports whether token is still valid, i.e. isTokenExpired
// says it has not expired. Lookup errors are returned with false.
func (r *LeveldbSimpleTokenOperator) validateToken(token string) (bool, error) {
	expired, err := r.isTokenExpired(token)
	if err == nil {
		return !expired, nil
	}
	return false, err
}
gin中间件封装
middlewareUtil.go
package middleware
import (
"fmt"
"gin-epg/src/common/util"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)
// TokenMiddleware returns a gin middleware that authenticates requests with
// an opaque token looked up in LevelDB via ParseToken (this is not a JWT).
//
// The token is read from "Authorization: Bearer <token>" or, when that header
// is absent, from the "token" header. On success the "id", "username", "role"
// and "avatarUrl" values stored with the token are placed in the gin context
// for later handlers (e.g. c.Get("username")). Otherwise the request is
// aborted with:
//   - 400 {"code": 2004, ...} for a malformed Authorization header,
//   - 401 {"error": ...} when no token is supplied,
//   - 401 {"code": 2005, ...} when ParseToken fails or finds no user data.
func TokenMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var token string
		if authHeader := c.GetHeader("Authorization"); authHeader != "" {
			parts := strings.SplitN(authHeader, " ", 2)
			// The auth scheme is case-insensitive (RFC 7235), so "bearer" is accepted too.
			if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
				c.JSON(http.StatusBadRequest, gin.H{
					"code": 2004,
					"msg":  "请求头中auth格式有误",
				})
				c.Abort()
				return
			}
			token = strings.TrimSpace(parts[1])
		} else {
			// Fall back to a plain "token" header.
			token = c.GetHeader("token")
			if token == "" {
				c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供 token"})
				c.Abort()
				return
			}
		}

		claims, err := ParseToken(token)
		// A nil map without an error means no user data was found for the
		// token; reject it instead of letting the request through.
		if err != nil || claims == nil {
			c.JSON(http.StatusUnauthorized, gin.H{
				"code": 2005,
				"msg":  "无效的Token",
			})
			c.Abort()
			return
		}

		c.Set("id", claims["id"])
		c.Set("username", claims["username"])
		c.Set("role", claims["role"])
		c.Set("avatarUrl", claims["avatarUrl"])
		c.Next()
	}
}
// defaultLeveldbRPCAddr is the LevelDB gRPC server address used when the
// "leveldbRpcUrl" config value is unavailable or is not a string.
const defaultLeveldbRPCAddr = "localhost:3051"

// leveldbRPCAddr resolves the LevelDB gRPC server address from the
// "leveldbRpcUrl" config entry, falling back to defaultLeveldbRPCAddr.
// The comma-ok assertion replaces the former unchecked configAddr.(string),
// which panicked when the config value was not a string.
func leveldbRPCAddr() string {
	value, err := util.GetConfigValue("leveldbRpcUrl")
	if err != nil {
		return defaultLeveldbRPCAddr
	}
	addr, ok := value.(string)
	if !ok {
		return defaultLeveldbRPCAddr
	}
	return addr
}

// ParseToken looks up token s in LevelDB and returns the user information
// stored with it. Errors from creating the operator or from the lookup are
// returned unchanged.
//
// NOTE(review): an operator is constructed on every call, i.e. on every
// authenticated request. Confirm that NewLeveldbSimpleTokenOperator reuses its
// gRPC connection; otherwise each call opens a connection that is never closed.
func ParseToken(s string) (map[string]interface{}, error) {
	operator, err := NewLeveldbSimpleTokenOperator(leveldbRPCAddr())
	if err != nil {
		return nil, err
	}
	return operator.getUserMapFromToken(s)
}

// GenerateToken stores userMap in LevelDB under a newly generated token and
// returns the token. Errors are wrapped with the step that failed.
func GenerateToken(userMap map[string]interface{}) (string, error) {
	operator, err := NewLeveldbSimpleTokenOperator(leveldbRPCAddr())
	if err != nil {
		return "", fmt.Errorf("failed to create LevelDB token operator: %w", err)
	}
	token, err := operator.generateToken(userMap)
	if err != nil {
		return "", fmt.Errorf("failed to generate token: %w", err)
	}
	return token, nil
}
登录时生成token调用示例
package controller
import (
"crypto/md5"
"fmt"
"gin-epg/src/entity"
"gin-epg/src/middleware"
"gin-epg/src/service"
"github.com/gin-gonic/gin"
"net/http"
)
type (
	// RestResponse is the JSON envelope written by the controller handlers.
	// Result and Message are left out of the JSON when empty.
	RestResponse struct {
		Status  int         `json:"status"`
		Result  interface{} `json:"result,omitempty"`
		Message string      `json:"message,omitempty"`
	}

	// LoginData is the JSON request body bound by DoLogin.
	LoginData struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
)
// UserInfo is the successful-login payload: the issued token plus the user record.
type UserInfo struct {
Token string `json:"token"`
User *entity.User `json:"user"`
}
// DoLogin handles a JSON login request (see LoginData).
//
// It looks the user up by name, compares MD5(password) with the stored value
// and on success issues a token via middleware.GenerateToken. Responses:
//   - 400 when the body cannot be bound,
//   - 401 "登录失败" for a lookup error, an unknown user or a wrong password,
//   - 500 when token generation fails,
//   - 200 with UserInfo{Token, User} otherwise.
//
// NOTE(review): unsalted MD5 is not suitable for password storage, and the
// string comparison is not constant-time. Moving to bcrypt/argon2 needs a
// migration of the stored hashes, so it is only flagged here.
func DoLogin(c *gin.Context) {
	var loginData LoginData
	if err := c.ShouldBindJSON(&loginData); err != nil {
		c.JSON(http.StatusBadRequest, RestResponse{Status: 400, Message: "请求参数错误"})
		return
	}

	loginUser, err := service.FindUserByName(loginData.Username)
	// Unknown users and wrong passwords get the same response, so the endpoint
	// cannot be used to discover which usernames exist.
	if err != nil || loginUser == nil || loginUser.Password != MD5(loginData.Password) {
		c.JSON(http.StatusUnauthorized, RestResponse{Status: 401, Message: "登录失败"})
		return
	}

	userInfoClaims := map[string]interface{}{
		"id":       loginUser.ID,
		"username": loginUser.Username,
		// NOTE(review): every user gets role "user"; confirm no per-user role exists.
		"role":      "user",
		"avatarUrl": loginUser.AvatarURL,
	}
	token, err := middleware.GenerateToken(userInfoClaims)
	if err != nil {
		c.JSON(http.StatusInternalServerError, RestResponse{Status: 500, Message: "生成Token失败"})
		return
	}

	// Respond with a copy whose password hash is cleared. entity.User's JSON
	// tags are not visible here, so don't rely on them to keep the hash out of
	// the response. NOTE(review): assumes entity.User is safe to copy by value.
	publicUser := *loginUser
	publicUser.Password = ""
	c.JSON(http.StatusOK, RestResponse{Status: 200, Result: UserInfo{Token: token, User: &publicUser}})
}
// MD5 returns the lowercase hexadecimal MD5 digest of str (32 characters).
func MD5(str string) string {
	return fmt.Sprintf("%x", md5.Sum([]byte(str)))
}
调用接口校验token示例
单个接口使用
epgChannelGroup.GET("/deleteChannelByName", middleware.TokenMiddleware(), controller.DeleteEpgChannelByName)
group使用
epgChannelGroup.Use(middleware.TokenMiddleware())