package main import ( "context" "encoding/json" "flag" "fmt" "net/http" "os" "os/signal" "sync" "syscall" "time" "plp-test/internal/config" "plp-test/internal/model" "plp-test/internal/testcase" "plp-test/internal/utils" "github.com/sirupsen/logrus" ) var ( configFile string logLevel string ) func init() { flag.StringVar(&configFile, "config", "config.yaml", "配置文件路径") flag.StringVar(&logLevel, "log-level", "info", "日志级别 (debug, info, warn, error)") } // TestRunner 测试运行器 type TestRunner struct { config *config.Config logger *logrus.Logger factory *testcase.TestCaseFactory tests map[string]testcase.TestCase testsMu sync.RWMutex testResult map[string]*model.TestResult resultMu sync.RWMutex streams map[string]map[string]http.ResponseWriter streamsMu sync.RWMutex integrityInfo map[string]*model.IntegrityInfo integrityMu sync.RWMutex } // NewTestRunner 创建测试运行器 func NewTestRunner(cfg *config.Config, logger *logrus.Logger) *TestRunner { return &TestRunner{ config: cfg, logger: logger, factory: testcase.NewTestCaseFactory(cfg, logger), tests: make(map[string]testcase.TestCase), testResult: make(map[string]*model.TestResult), streams: make(map[string]map[string]http.ResponseWriter), integrityInfo: make(map[string]*model.IntegrityInfo), } } // RunTest 运行指定的测试 func (r *TestRunner) RunTest(testType string) (*model.TestResult, error) { r.logger.Infof("准备运行测试: %s", testType) // 创建测试实例 test, err := r.factory.CreateTestCase(testType) if err != nil { return nil, fmt.Errorf("创建测试用例失败: %v", err) } if test == nil { return nil, fmt.Errorf("未找到测试用例: %s", testType) } // 存储测试实例 testID := test.Status().TestID r.testsMu.Lock() r.tests[testID] = test r.testsMu.Unlock() // 创建上下文以便可以取消测试 ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 发送测试开始状态更新 r.sendStatusUpdate(test) // 设置测试环境 r.logger.Info("设置测试环境") if err := test.Setup(ctx, false); err != nil { r.logger.Errorf("设置测试环境失败: %v", err) r.sendErrorUpdate(testID, fmt.Sprintf("设置测试环境失败: %v", err)) return nil, err } // 启动状态监控协程 statusDone := make(chan struct{}) go func() { defer close(statusDone) ticker := time.NewTicker(200 * time.Millisecond) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: // 发送状态更新 r.sendStatusUpdate(test) } } }() // 运行测试 r.logger.Info("运行测试") result, err := test.Run(ctx) if err != nil { r.logger.Errorf("测试运行失败: %v", err) r.sendErrorUpdate(testID, fmt.Sprintf("测试运行失败: %v", err)) // 尝试清理 cleanupErr := test.Cleanup(ctx) if cleanupErr != nil { r.logger.Errorf("测试清理失败: %v", cleanupErr) } return nil, err } // 清理测试环境 r.logger.Info("清理测试环境") if err := test.Cleanup(ctx); err != nil { r.logger.Errorf("测试清理失败: %v", err) r.sendErrorUpdate(testID, fmt.Sprintf("测试清理失败: %v", err)) return nil, err } // 停止状态监控 cancel() <-statusDone // 存储测试结果 r.resultMu.Lock() r.testResult[testID] = result r.resultMu.Unlock() // 移除测试实例 r.testsMu.Lock() delete(r.tests, testID) r.testsMu.Unlock() // 发送完成通知 r.sendCompletionUpdate(testID, result) r.logger.Infof("测试 %s 完成", testType) return result, nil } // sendStatusUpdate 发送状态更新 func (r *TestRunner) sendStatusUpdate(test testcase.TestCase) { status := test.Status() update := model.StreamUpdate{ Type: "status", TestID: status.TestID, Timestamp: time.Now(), Progress: status.Progress, CurrentPhase: status.CurrentPhase, Message: status.Message, Data: status, } r.SendStreamUpdate(status.TestID, update) } // sendErrorUpdate 发送错误更新 func (r *TestRunner) sendErrorUpdate(testID, message string) { update := model.StreamUpdate{ Type: "error", TestID: testID, Timestamp: time.Now(), Message: message, } r.SendStreamUpdate(testID, update) } // sendCompletionUpdate 发送完成更新 func (r *TestRunner) sendCompletionUpdate(testID string, result *model.TestResult) { update := model.StreamUpdate{ Type: "completion", TestID: testID, Timestamp: time.Now(), Progress: 100, Message: "测试完成", Data: result, } r.SendStreamUpdate(testID, update) } // sendIntegrityUpdate 发送完整性更新 func (r *TestRunner) sendIntegrityUpdate(testID string, message string, info *model.IntegrityInfo) { update := model.StreamUpdate{ Type: "integrity", TestID: testID, Timestamp: time.Now(), Message: message, Data: info, } r.SendStreamUpdate(testID, update) } // GetTestStatus 获取测试状态 func (r *TestRunner) GetTestStatus(testID string) *model.TestStatus { r.testsMu.RLock() defer r.testsMu.RUnlock() if test, ok := r.tests[testID]; ok { return test.Status() } return nil } // GetAllTestStatus 获取所有测试状态 func (r *TestRunner) GetAllTestStatus() []*model.TestStatus { r.testsMu.RLock() defer r.testsMu.RUnlock() statuses := make([]*model.TestStatus, 0, len(r.tests)) for _, test := range r.tests { statuses = append(statuses, test.Status()) } return statuses } // RegisterStream 注册流式连接 func (r *TestRunner) RegisterStream(testID, clientID string, w http.ResponseWriter) { r.streamsMu.Lock() defer r.streamsMu.Unlock() if _, ok := r.streams[testID]; !ok { r.streams[testID] = make(map[string]http.ResponseWriter) } r.streams[testID][clientID] = w r.logger.Infof("客户端 %s 已连接到测试 %s 的流", clientID, testID) } // UnregisterStream 注销流式连接 func (r *TestRunner) UnregisterStream(testID, clientID string) { r.streamsMu.Lock() defer r.streamsMu.Unlock() if clients, ok := r.streams[testID]; ok { delete(clients, clientID) r.logger.Infof("客户端 %s 已断开与测试 %s 的流连接", clientID, testID) } } // SendStreamUpdate 发送流式更新 func (r *TestRunner) SendStreamUpdate(testID string, update interface{}) { r.streamsMu.RLock() defer r.streamsMu.RUnlock() clients, ok := r.streams[testID] if !ok || len(clients) == 0 { return } data, err := json.Marshal(update) if err != nil { r.logger.Errorf("无法序列化流更新: %v", err) return } for clientID, w := range clients { // 使用Server-Sent Events格式 _, err := fmt.Fprintf(w, "data: %s\n\n", data) if err != nil { r.logger.Warnf("向客户端 %s 发送更新失败: %v", clientID, err) } else { if f, ok := w.(http.Flusher); ok { f.Flush() } } } } // SaveIntegrityInfo 保存完整性信息 func (r *TestRunner) SaveIntegrityInfo(testID string, info *model.IntegrityInfo) { r.integrityMu.Lock() defer r.integrityMu.Unlock() r.integrityInfo[testID] = info } // GetIntegrityInfo 获取完整性信息 func (r *TestRunner) GetIntegrityInfo(testID string) *model.IntegrityInfo { r.integrityMu.RLock() defer r.integrityMu.RUnlock() return r.integrityInfo[testID] } // StartServer 启动HTTP服务器 func StartServer(cfg *config.Config, runner *TestRunner, logger *logrus.Logger) *http.Server { mux := http.NewServeMux() // 健康检查接口 mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { health := &model.HealthStatus{ Status: "ok", Timestamp: time.Now(), Message: "服务正常运行", } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(health) }) // 运行测试接口 mux.HandleFunc("/run", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } var req model.TestRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } logger.Infof("收到测试请求: %+v", req) // 异步运行测试 go func() { result, err := runner.RunTest(req.TestType) if err != nil { logger.Errorf("测试运行失败: %v", err) } else { logger.Infof("测试完成: %+v", result) } }() resp := model.TestResponse{ RequestID: req.TestType + "-" + time.Now().Format("20060102-150405"), Status: "accepted", Message: "测试已接受并开始执行", ServerTime: time.Now(), } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) json.NewEncoder(w).Encode(resp) }) // 获取测试状态接口 mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } testID := r.URL.Query().Get("test_id") var status interface{} if testID == "" { // 获取所有测试状态 status = runner.GetAllTestStatus() } else { // 获取指定测试状态 status = runner.GetTestStatus(testID) if status == nil { http.Error(w, "Test not found", http.StatusNotFound) return } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) }) // 新增: 实时数据进度流式API mux.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) { testID := r.URL.Query().Get("test_id") if testID == "" { http.Error(w, "Missing test_id", http.StatusBadRequest) return } // 设置响应头,支持SSE (Server-Sent Events) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") // 创建完成通道 doneCh := make(chan struct{}) defer close(doneCh) // 注册客户端连接 clientID := r.URL.Query().Get("client_id") runner.RegisterStream(testID, clientID, w) defer runner.UnregisterStream(testID, clientID) // 保持连接直到客户端断开 select { case <-r.Context().Done(): runner.logger.Infof("connection closed by client %s", clientID) return case <-doneCh: runner.logger.Infof("connection closed by server for client %s", clientID) return } }) // 新增: 数据完整性检测API mux.HandleFunc("/integrity", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } testID := r.URL.Query().Get("test_id") if testID == "" { http.Error(w, "Missing test_id", http.StatusBadRequest) return } // 获取测试的数据完整性信息 integrityInfo := runner.GetIntegrityInfo(testID) if integrityInfo == nil { http.Error(w, "Integrity info not found", http.StatusNotFound) return } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(integrityInfo) }) // 新增: 恢复测试API,用于断电测试后的恢复与校验 mux.HandleFunc("/recovery", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } var req struct { TestType string `json:"test_type"` TestDir string `json:"test_dir"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Invalid request body", http.StatusBadRequest) return } logger.Infof("收到恢复测试请求: %+v", req) // 创建恢复测试实例 test, err := runner.factory.CreateTestCase(req.TestType) if err != nil || test == nil { http.Error(w, fmt.Sprintf("无法创建测试实例: %v", err), http.StatusBadRequest) return } // 获取测试ID testID := test.Status().TestID // 执行恢复和数据完整性检查 go func() { ctx := context.Background() // 设置测试环境 logger.Info("设置恢复测试环境") if err := test.Setup(ctx, true); err != nil { logger.Errorf("设置恢复测试环境失败: %v", err) runner.sendErrorUpdate(testID, fmt.Sprintf("设置恢复测试环境失败: %v", err)) return } // 数据完整性检查 logger.Info("执行数据完整性检查") runner.sendStatusUpdate(test) // 检查并获取数据完整性信息 if powerTest, ok := test.(*testcase.PowerLossTest); ok { integrityInfo := powerTest.CheckIntegrity() go func() { time.Sleep(1 * time.Second) runner.sendIntegrityUpdate(testID, "开始数据完整性检查", nil) }() // 保存完整性信息 runner.SaveIntegrityInfo(testID, integrityInfo) // 发送完整性信息 runner.sendIntegrityUpdate(testID, "数据完整性检查完成", integrityInfo) logger.Infof("恢复测试完成: 丢失数据: %.2f MB", integrityInfo.DataLossMB) } else { logger.Error("不是断电测试实例,无法执行数据完整性检查") runner.sendErrorUpdate(testID, "不是断电测试实例,无法执行数据完整性检查") } // 清理测试环境 logger.Info("清理恢复测试环境") if err := test.Cleanup(ctx); err != nil { logger.Errorf("清理恢复测试环境失败: %v", err) } }() // 返回接受响应 resp := model.TestResponse{ RequestID: testID, Status: "accepted", Message: "恢复测试已接受并开始执行", ServerTime: time.Now(), } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) json.NewEncoder(w).Encode(resp) }) // 启动服务器 addr := fmt.Sprintf("%s:%d", cfg.Server.ListenAddr, cfg.Server.Port) server := &http.Server{ Addr: addr, Handler: mux, } go func() { logger.Infof("服务器启动在 %s", addr) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { logger.Fatalf("服务器启动失败: %v", err) } }() return server } func main() { flag.Parse() // 初始化日志级别 var level logrus.Level switch logLevel { case "debug": level = logrus.DebugLevel case "info": level = logrus.InfoLevel case "warn": level = logrus.WarnLevel case "error": level = logrus.ErrorLevel default: level = logrus.InfoLevel } // 初始化日志 logger := logrus.New() logger.SetLevel(level) logger.SetFormatter(&logrus.TextFormatter{ FullTimestamp: true, TimestampFormat: "2006-01-02 15:04:05", }) // 加载配置 logger.Infof("加载配置文件: %s", configFile) cfg, err := config.Load(configFile) if err != nil { logger.Fatalf("加载配置失败: %v", err) } // 初始化日志文件 if cfg.Server.LogFile != "" { utils.InitLogger(cfg.Server.LogFile, level) logger = utils.Logger } // 创建测试运行器 runner := NewTestRunner(cfg, logger) // 启动服务器 server := StartServer(cfg, runner, logger) // 等待终止信号 stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop logger.Info("正在关闭服务器...") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { logger.Fatalf("服务器强制关闭: %v", err) } logger.Info("服务器已优雅关闭") }