plp-test/cmd/server/main.go

598 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
"plp-test/internal/config"
"plp-test/internal/model"
"plp-test/internal/testcase"
"plp-test/internal/utils"
"github.com/sirupsen/logrus"
)
// Command-line options; flag.Parse in main fills them in.
var configFile, logLevel string

// init registers the string flags backing configFile and logLevel.
func init() {
	options := []struct {
		target       *string
		name, defVal string
		usage        string
	}{
		{&configFile, "config", "config.yaml", "配置文件路径"},
		{&logLevel, "log-level", "info", "日志级别 (debug, info, warn, error)"},
	}
	for _, o := range options {
		flag.StringVar(o.target, o.name, o.defVal, o.usage)
	}
}
// TestRunner creates test cases, runs them and fans their progress out to
// streaming (SSE) clients. Each map below is guarded by the RWMutex declared
// directly above it. The struct holds mutexes and must not be copied.
type TestRunner struct {
	config  *config.Config
	logger  *logrus.Logger
	factory *testcase.TestCaseFactory

	// Registered test instances, keyed by their TestID.
	testsMu sync.RWMutex
	tests   map[string]testcase.TestCase

	// Results of tests that finished successfully, keyed by TestID.
	resultMu   sync.RWMutex
	testResult map[string]*model.TestResult

	// Stream subscribers: TestID -> client ID -> response writer.
	streamsMu sync.RWMutex
	streams   map[string]map[string]http.ResponseWriter

	// Integrity-check results, keyed by TestID.
	integrityMu   sync.RWMutex
	integrityInfo map[string]*model.IntegrityInfo
}
// NewTestRunner returns a TestRunner bound to cfg and logger, with a test-case
// factory built from the same pair and every bookkeeping map initialised empty.
func NewTestRunner(cfg *config.Config, logger *logrus.Logger) *TestRunner {
	tr := &TestRunner{config: cfg, logger: logger}
	tr.factory = testcase.NewTestCaseFactory(cfg, logger)
	tr.tests = make(map[string]testcase.TestCase)
	tr.testResult = make(map[string]*model.TestResult)
	tr.streams = make(map[string]map[string]http.ResponseWriter)
	tr.integrityInfo = make(map[string]*model.IntegrityInfo)
	return tr
}
// RunTest creates a test case of type testType through the factory and drives
// it synchronously through Setup, Run and Cleanup, returning the Run result.
//
// Stream subscribers for the test's ID receive updates in this order:
//   - an initial "status" update;
//   - after Setup succeeds, a "status" update every 200ms while the test runs
//     (and, on the success path, while it is cleaned up);
//   - one terminal "completion" or "error" update.
//
// The periodic monitor is always stopped and joined before the terminal update
// is sent, so no "status" event can follow "error" or "completion".
//
// On success, the result is stored in r.testResult and the instance is removed
// from r.tests.
// NOTE(review): on any failure the instance stays in r.tests, so it keeps
// appearing in /status indefinitely. Confirm this is intended.
//
// The context handed to the test is cancelled when RunTest returns. Callers
// have no way to cancel a running test.
func (r *TestRunner) RunTest(testType string) (*model.TestResult, error) {
	r.logger.Infof("准备运行测试: %s", testType)

	test, err := r.factory.CreateTestCase(testType)
	if err != nil {
		return nil, fmt.Errorf("创建测试用例失败: %w", err)
	}
	if test == nil {
		return nil, fmt.Errorf("未找到测试用例: %s", testType)
	}

	// Register the instance so /status can report on it while it runs.
	testID := test.Status().TestID
	r.testsMu.Lock()
	r.tests[testID] = test
	r.testsMu.Unlock()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Announce the test before Setup.
	r.sendStatusUpdate(test)

	r.logger.Info("设置测试环境")
	if err := test.Setup(ctx, false); err != nil {
		r.logger.Errorf("设置测试环境失败: %v", err)
		r.sendErrorUpdate(testID, fmt.Sprintf("设置测试环境失败: %v", err))
		return nil, err
	}

	stopMonitor := r.startStatusMonitor(ctx, test)
	// Safety net for unexpected exits; stopMonitor is idempotent.
	defer stopMonitor()

	r.logger.Info("运行测试")
	result, err := test.Run(ctx)
	if err != nil {
		r.logger.Errorf("测试运行失败: %v", err)
		stopMonitor()
		r.sendErrorUpdate(testID, fmt.Sprintf("测试运行失败: %v", err))
		// Best-effort cleanup; the Run error is the one returned.
		if cleanupErr := test.Cleanup(ctx); cleanupErr != nil {
			r.logger.Errorf("测试清理失败: %v", cleanupErr)
		}
		return nil, err
	}

	r.logger.Info("清理测试环境")
	cleanupErr := test.Cleanup(ctx)
	stopMonitor()
	if cleanupErr != nil {
		r.logger.Errorf("测试清理失败: %v", cleanupErr)
		r.sendErrorUpdate(testID, fmt.Sprintf("测试清理失败: %v", cleanupErr))
		return nil, cleanupErr
	}

	r.resultMu.Lock()
	r.testResult[testID] = result
	r.resultMu.Unlock()

	r.testsMu.Lock()
	delete(r.tests, testID)
	r.testsMu.Unlock()

	r.sendCompletionUpdate(testID, result)
	r.logger.Infof("测试 %s 完成", testType)
	return result, nil
}

// startStatusMonitor starts a goroutine that sends a status update for test
// every 200ms until ctx is cancelled or the returned stop function is called.
// stop cancels the goroutine and blocks until it has exited. Calling it more
// than once is safe.
func (r *TestRunner) startStatusMonitor(ctx context.Context, test testcase.TestCase) (stop func()) {
	monitorCtx, cancel := context.WithCancel(ctx)
	done := make(chan struct{})
	go func() {
		defer close(done)
		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-monitorCtx.Done():
				return
			case <-ticker.C:
				r.sendStatusUpdate(test)
			}
		}
	}()
	return func() {
		cancel()
		// Receiving from the closed channel returns at once on repeat calls.
		<-done
	}
}
// sendStatusUpdate publishes a "status" stream event built from test.Status()
// to that test's subscribers, attaching the full status value as Data.
func (r *TestRunner) sendStatusUpdate(test testcase.TestCase) {
	snapshot := test.Status()
	r.SendStreamUpdate(snapshot.TestID, model.StreamUpdate{
		Type:         "status",
		TestID:       snapshot.TestID,
		Timestamp:    time.Now(),
		Progress:     snapshot.Progress,
		CurrentPhase: snapshot.CurrentPhase,
		Message:      snapshot.Message,
		Data:         snapshot,
	})
}
// sendErrorUpdate publishes an "error" stream event carrying message to the
// subscribers of testID.
func (r *TestRunner) sendErrorUpdate(testID, message string) {
	r.SendStreamUpdate(testID, model.StreamUpdate{
		Type:      "error",
		TestID:    testID,
		Timestamp: time.Now(),
		Message:   message,
	})
}
// sendCompletionUpdate publishes a "completion" stream event at 100% progress,
// with the final result attached as Data, to the subscribers of testID.
func (r *TestRunner) sendCompletionUpdate(testID string, result *model.TestResult) {
	done := model.StreamUpdate{Type: "completion", TestID: testID, Timestamp: time.Now()}
	done.Progress = 100
	done.Message = "测试完成"
	done.Data = result
	r.SendStreamUpdate(testID, done)
}
// sendIntegrityUpdate publishes an "integrity" stream event with message and
// the given integrity info (which may be nil) as Data to testID's subscribers.
func (r *TestRunner) sendIntegrityUpdate(testID string, message string, info *model.IntegrityInfo) {
	r.SendStreamUpdate(testID, model.StreamUpdate{
		Type:      "integrity",
		TestID:    testID,
		Timestamp: time.Now(),
		Message:   message,
		Data:      info,
	})
}
// GetTestStatus returns the current status of the registered test with the
// given ID, or nil when no such test is registered.
func (r *TestRunner) GetTestStatus(testID string) *model.TestStatus {
	r.testsMu.RLock()
	defer r.testsMu.RUnlock()

	test, found := r.tests[testID]
	if !found {
		return nil
	}
	return test.Status()
}
// GetAllTestStatus returns the status of every registered test, in map
// iteration order. The slice is non-nil even when no test is registered.
func (r *TestRunner) GetAllTestStatus() []*model.TestStatus {
	r.testsMu.RLock()
	defer r.testsMu.RUnlock()

	out := make([]*model.TestStatus, 0, len(r.tests))
	for _, tc := range r.tests {
		out = append(out, tc.Status())
	}
	return out
}
// RegisterStream subscribes writer w, identified by clientID, to the stream
// updates of testID. It replaces any writer already registered under the same
// (testID, clientID) pair.
func (r *TestRunner) RegisterStream(testID, clientID string, w http.ResponseWriter) {
	r.streamsMu.Lock()
	defer r.streamsMu.Unlock()

	clients, ok := r.streams[testID]
	if !ok {
		clients = make(map[string]http.ResponseWriter)
		r.streams[testID] = clients
	}
	clients[clientID] = w
	r.logger.Infof("客户端 %s 已连接到测试 %s 的流", clientID, testID)
}
// UnregisterStream removes clientID from the subscribers of testID. When the
// last subscriber of a test leaves, the test's entry is removed as well.
// Otherwise every test ID that was ever streamed would keep an empty map alive.
// Unknown test IDs are ignored.
func (r *TestRunner) UnregisterStream(testID, clientID string) {
	r.streamsMu.Lock()
	defer r.streamsMu.Unlock()

	clients, ok := r.streams[testID]
	if !ok {
		return
	}
	delete(clients, clientID)
	if len(clients) == 0 {
		delete(r.streams, testID)
	}
	r.logger.Infof("客户端 %s 已断开与测试 %s 的流连接", clientID, testID)
}
// SendStreamUpdate encodes update as JSON and writes it as a Server-Sent Events
// frame ("data: <json>\n\n") to every client registered for testID. Each writer
// that implements http.Flusher is flushed after a successful write. A client
// whose write fails is logged and skipped; it stays registered until
// UnregisterStream.
//
// The exclusive lock (not the read lock) is held for the whole fan-out. This
// method can be called from several goroutines for the same test ID, and
// http.ResponseWriter is not safe for concurrent use, so writes to any one
// client must be serialized. Holding the lock also keeps UnregisterStream
// (and therefore the /stream handler) from finishing while a write is in
// progress.
func (r *TestRunner) SendStreamUpdate(testID string, update interface{}) {
	r.streamsMu.Lock()
	defer r.streamsMu.Unlock()

	clients := r.streams[testID]
	if len(clients) == 0 {
		return
	}

	data, err := json.Marshal(update)
	if err != nil {
		r.logger.Errorf("无法序列化流更新: %v", err)
		return
	}

	for clientID, w := range clients {
		if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
			r.logger.Warnf("向客户端 %s 发送更新失败: %v", clientID, err)
			continue
		}
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
	}
}
// SaveIntegrityInfo stores info as the integrity result for testID, replacing
// any earlier one. The pointer itself is stored, not a copy.
func (r *TestRunner) SaveIntegrityInfo(testID string, info *model.IntegrityInfo) {
	r.integrityMu.Lock()
	r.integrityInfo[testID] = info
	r.integrityMu.Unlock()
}
// GetIntegrityInfo returns the integrity result stored for testID, or nil if
// none has been saved.
func (r *TestRunner) GetIntegrityInfo(testID string) *model.IntegrityInfo {
	r.integrityMu.RLock()
	info := r.integrityInfo[testID]
	r.integrityMu.RUnlock()
	return info
}
// StartServer registers the HTTP API on a new ServeMux, starts listening on
// cfg.Server.ListenAddr:cfg.Server.Port in a background goroutine and returns
// the *http.Server so the caller can shut it down. A listen error other than
// http.ErrServerClosed is reported through logger.Fatalf.
//
// Routes:
//
//	/health              liveness probe (any method)
//	/run       (POST)    start a test asynchronously; body is a model.TestRequest
//	/status    (GET)     status of every registered test, or of ?test_id=
//	/stream              Server-Sent Events for ?test_id= (optional ?client_id=)
//	/integrity (GET)     stored integrity info for ?test_id=
//	/recovery  (POST)    post-power-loss recovery and integrity check
//
// Open /stream connections are released as soon as Shutdown is called, so a
// graceful shutdown does not stall until its deadline while SSE clients remain
// connected.
func StartServer(cfg *config.Config, runner *TestRunner, logger *logrus.Logger) *http.Server {
	mux := http.NewServeMux()

	// serverClosing is cancelled when Shutdown starts (registered below via
	// RegisterOnShutdown). Shutdown itself does not cancel in-flight request
	// contexts, so /stream handlers watch this instead.
	serverClosing, signalClosing := context.WithCancel(context.Background())

	// Health check.
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		health := &model.HealthStatus{
			Status:    "ok",
			Timestamp: time.Now(),
			Message:   "服务正常运行",
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(health)
	})

	// Start a test asynchronously; the response only acknowledges the request.
	mux.HandleFunc("/run", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		var req model.TestRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		logger.Infof("收到测试请求: %+v", req)

		go func() {
			result, err := runner.RunTest(req.TestType)
			if err != nil {
				logger.Errorf("测试运行失败: %v", err)
			} else {
				logger.Infof("测试完成: %+v", result)
			}
		}()

		// NOTE(review): RequestID is built here independently of the TestID that
		// RunTest registers, so clients cannot use it with /status or /stream
		// unless the test case derives its TestID the same way. Confirm.
		resp := model.TestResponse{
			RequestID:  req.TestType + "-" + time.Now().Format("20060102-150405"),
			Status:     "accepted",
			Message:    "测试已接受并开始执行",
			ServerTime: time.Now(),
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusAccepted)
		json.NewEncoder(w).Encode(resp)
	})

	// Status of all registered tests, or of the one named by ?test_id=.
	mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		testID := r.URL.Query().Get("test_id")
		var status interface{}
		if testID == "" {
			status = runner.GetAllTestStatus()
		} else {
			// Check the typed pointer before storing it in the interface: an
			// interface holding a nil *model.TestStatus is not == nil.
			one := runner.GetTestStatus(testID)
			if one == nil {
				http.Error(w, "Test not found", http.StatusNotFound)
				return
			}
			status = one
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(status)
	})

	// Real-time progress stream (Server-Sent Events).
	mux.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) {
		testID := r.URL.Query().Get("test_id")
		if testID == "" {
			http.Error(w, "Missing test_id", http.StatusBadRequest)
			return
		}
		// SSE has to push bytes before the handler returns.
		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("Access-Control-Allow-Origin", "*")
		// Commit the headers now so the client sees the stream open before the
		// first update. This happens before RegisterStream, so it cannot race
		// with SendStreamUpdate writing to w.
		w.WriteHeader(http.StatusOK)
		flusher.Flush()

		// client_id is optional. Without a unique fallback, concurrent clients
		// that omit it would share the "" key: each would overwrite the other,
		// and the first to disconnect would unregister both.
		clientID := r.URL.Query().Get("client_id")
		if clientID == "" {
			clientID = fmt.Sprintf("%s-%d", r.RemoteAddr, time.Now().UnixNano())
		}
		runner.RegisterStream(testID, clientID, w)
		defer runner.UnregisterStream(testID, clientID)

		// Hold the connection open until the client leaves or the server shuts down.
		select {
		case <-r.Context().Done():
			runner.logger.Infof("connection closed by client %s", clientID)
		case <-serverClosing.Done():
			runner.logger.Infof("connection closed by server for client %s", clientID)
		}
	})

	// Data-integrity results for ?test_id=.
	mux.HandleFunc("/integrity", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		testID := r.URL.Query().Get("test_id")
		if testID == "" {
			http.Error(w, "Missing test_id", http.StatusBadRequest)
			return
		}
		integrityInfo := runner.GetIntegrityInfo(testID)
		if integrityInfo == nil {
			http.Error(w, "Integrity info not found", http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(integrityInfo)
	})

	// Recovery and integrity verification after a power-loss test.
	mux.HandleFunc("/recovery", func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}
		var req struct {
			TestType string `json:"test_type"`
			TestDir  string `json:"test_dir"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}
		logger.Infof("收到恢复测试请求: %+v", req)

		test, err := runner.factory.CreateTestCase(req.TestType)
		if err != nil || test == nil {
			http.Error(w, fmt.Sprintf("无法创建测试实例: %v", err), http.StatusBadRequest)
			return
		}
		testID := test.Status().TestID

		go func() {
			ctx := context.Background()

			logger.Info("设置恢复测试环境")
			if err := test.Setup(ctx, true); err != nil {
				logger.Errorf("设置恢复测试环境失败: %v", err)
				runner.sendErrorUpdate(testID, fmt.Sprintf("设置恢复测试环境失败: %v", err))
				return
			}

			logger.Info("执行数据完整性检查")
			runner.sendStatusUpdate(test)

			if powerTest, ok := test.(*testcase.PowerLossTest); ok {
				// Announce the start before the (synchronous) check runs, so it
				// precedes the completion update.
				runner.sendIntegrityUpdate(testID, "开始数据完整性检查", nil)
				integrityInfo := powerTest.CheckIntegrity()
				if integrityInfo == nil {
					// Guard against a nil result, which would otherwise panic below.
					logger.Error("数据完整性检查未返回结果")
					runner.sendErrorUpdate(testID, "数据完整性检查未返回结果")
				} else {
					runner.SaveIntegrityInfo(testID, integrityInfo)
					// BlocksMap is dropped before streaming the result.
					// NOTE(review): this clears it on the same object just stored by
					// SaveIntegrityInfo, so /integrity also serves it without
					// BlocksMap. Confirm this is intended.
					integrityInfo.BlocksMap = nil
					runner.sendIntegrityUpdate(testID, "数据完整性检查完成", integrityInfo)
					logger.Infof("恢复测试完成: 没填充的数据: %.2f MB", integrityInfo.DataLossMB)
				}
			} else {
				logger.Error("不是断电测试实例,无法执行数据完整性检查")
				runner.sendErrorUpdate(testID, "不是断电测试实例,无法执行数据完整性检查")
			}

			logger.Info("清理恢复测试环境")
			if err := test.Cleanup(ctx); err != nil {
				logger.Errorf("清理恢复测试环境失败: %v", err)
			}
		}()

		resp := model.TestResponse{
			RequestID:  testID,
			Status:     "accepted",
			Message:    "恢复测试已接受并开始执行",
			ServerTime: time.Now(),
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusAccepted)
		json.NewEncoder(w).Encode(resp)
	})

	addr := fmt.Sprintf("%s:%d", cfg.Server.ListenAddr, cfg.Server.Port)
	server := &http.Server{
		Addr:    addr,
		Handler: mux,
	}
	// Release /stream handlers as soon as Shutdown begins; otherwise every open
	// SSE connection would keep Shutdown waiting until its context deadline.
	server.RegisterOnShutdown(signalClosing)

	go func() {
		logger.Infof("服务器启动在 %s", addr)
		if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.Fatalf("服务器启动失败: %v", err)
		}
	}()
	return server
}
// main parses flags, configures logging, loads the configuration, starts the
// HTTP server and then blocks until SIGINT/SIGTERM. On a signal it shuts the
// server down, allowing up to 5 seconds.
func main() {
	flag.Parse()

	// Map -log-level to a logrus level; unrecognised values fall back to info.
	levels := map[string]logrus.Level{
		"debug": logrus.DebugLevel,
		"info":  logrus.InfoLevel,
		"warn":  logrus.WarnLevel,
		"error": logrus.ErrorLevel,
	}
	level, known := levels[logLevel]
	if !known {
		level = logrus.InfoLevel
	}

	logger := logrus.New()
	logger.SetLevel(level)
	logger.SetFormatter(&logrus.TextFormatter{
		FullTimestamp:   true,
		TimestampFormat: "2006-01-02 15:04:05",
	})

	logger.Infof("加载配置文件: %s", configFile)
	cfg, err := config.Load(configFile)
	if err != nil {
		logger.Fatalf("加载配置失败: %v", err)
	}

	// With a log file configured, switch to utils.Logger after utils.InitLogger
	// has set it up with the same level.
	if path := cfg.Server.LogFile; path != "" {
		utils.InitLogger(path, level)
		logger = utils.Logger
	}

	runner := NewTestRunner(cfg, logger)
	server := StartServer(cfg, runner, logger)

	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
	<-sigCh

	logger.Info("正在关闭服务器...")
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.Shutdown(shutdownCtx); err != nil {
		logger.Fatalf("服务器强制关闭: %v", err)
	}
	logger.Info("服务器已优雅关闭")
}