597 lines
15 KiB
Go
597 lines
15 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"flag"
|
||
"fmt"
|
||
"net/http"
|
||
"os"
|
||
"os/signal"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
|
||
"plp-test/internal/config"
|
||
"plp-test/internal/model"
|
||
"plp-test/internal/testcase"
|
||
"plp-test/internal/utils"
|
||
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// Command-line flags. They are registered in init and populated by
// flag.Parse in main.
var configFile, logLevel string

func init() {
	// -config: path to the configuration file (default "config.yaml").
	flag.StringVar(&configFile, "config", "config.yaml", "配置文件路径")
	// -log-level: logger verbosity; main falls back to info for unknown values.
	flag.StringVar(&logLevel, "log-level", "info", "日志级别 (debug, info, warn, error)")
}
|
||
|
||
// TestRunner 测试运行器
|
||
type TestRunner struct {
|
||
config *config.Config
|
||
logger *logrus.Logger
|
||
factory *testcase.TestCaseFactory
|
||
tests map[string]testcase.TestCase
|
||
testsMu sync.RWMutex
|
||
testResult map[string]*model.TestResult
|
||
resultMu sync.RWMutex
|
||
streams map[string]map[string]http.ResponseWriter
|
||
streamsMu sync.RWMutex
|
||
integrityInfo map[string]*model.IntegrityInfo
|
||
integrityMu sync.RWMutex
|
||
}
|
||
|
||
// NewTestRunner 创建测试运行器
|
||
func NewTestRunner(cfg *config.Config, logger *logrus.Logger) *TestRunner {
|
||
return &TestRunner{
|
||
config: cfg,
|
||
logger: logger,
|
||
factory: testcase.NewTestCaseFactory(cfg, logger),
|
||
tests: make(map[string]testcase.TestCase),
|
||
testResult: make(map[string]*model.TestResult),
|
||
streams: make(map[string]map[string]http.ResponseWriter),
|
||
integrityInfo: make(map[string]*model.IntegrityInfo),
|
||
}
|
||
}
|
||
|
||
// RunTest 运行指定的测试
|
||
func (r *TestRunner) RunTest(testType string) (*model.TestResult, error) {
|
||
r.logger.Infof("准备运行测试: %s", testType)
|
||
|
||
// 创建测试实例
|
||
test, err := r.factory.CreateTestCase(testType)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建测试用例失败: %v", err)
|
||
}
|
||
if test == nil {
|
||
return nil, fmt.Errorf("未找到测试用例: %s", testType)
|
||
}
|
||
|
||
// 存储测试实例
|
||
testID := test.Status().TestID
|
||
r.testsMu.Lock()
|
||
r.tests[testID] = test
|
||
r.testsMu.Unlock()
|
||
|
||
// 创建上下文以便可以取消测试
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
// 发送测试开始状态更新
|
||
r.sendStatusUpdate(test)
|
||
|
||
// 设置测试环境
|
||
r.logger.Info("设置测试环境")
|
||
if err := test.Setup(ctx, false); err != nil {
|
||
r.logger.Errorf("设置测试环境失败: %v", err)
|
||
r.sendErrorUpdate(testID, fmt.Sprintf("设置测试环境失败: %v", err))
|
||
return nil, err
|
||
}
|
||
|
||
// 启动状态监控协程
|
||
statusDone := make(chan struct{})
|
||
go func() {
|
||
defer close(statusDone)
|
||
ticker := time.NewTicker(200 * time.Millisecond)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
// 发送状态更新
|
||
r.sendStatusUpdate(test)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 运行测试
|
||
r.logger.Info("运行测试")
|
||
result, err := test.Run(ctx)
|
||
if err != nil {
|
||
r.logger.Errorf("测试运行失败: %v", err)
|
||
r.sendErrorUpdate(testID, fmt.Sprintf("测试运行失败: %v", err))
|
||
|
||
// 尝试清理
|
||
cleanupErr := test.Cleanup(ctx)
|
||
if cleanupErr != nil {
|
||
r.logger.Errorf("测试清理失败: %v", cleanupErr)
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// 清理测试环境
|
||
r.logger.Info("清理测试环境")
|
||
if err := test.Cleanup(ctx); err != nil {
|
||
r.logger.Errorf("测试清理失败: %v", err)
|
||
r.sendErrorUpdate(testID, fmt.Sprintf("测试清理失败: %v", err))
|
||
return nil, err
|
||
}
|
||
|
||
// 停止状态监控
|
||
cancel()
|
||
<-statusDone
|
||
|
||
// 存储测试结果
|
||
r.resultMu.Lock()
|
||
r.testResult[testID] = result
|
||
r.resultMu.Unlock()
|
||
|
||
// 移除测试实例
|
||
r.testsMu.Lock()
|
||
delete(r.tests, testID)
|
||
r.testsMu.Unlock()
|
||
|
||
// 发送完成通知
|
||
r.sendCompletionUpdate(testID, result)
|
||
|
||
r.logger.Infof("测试 %s 完成", testType)
|
||
return result, nil
|
||
}
|
||
|
||
// sendStatusUpdate 发送状态更新
|
||
func (r *TestRunner) sendStatusUpdate(test testcase.TestCase) {
|
||
status := test.Status()
|
||
update := model.StreamUpdate{
|
||
Type: "status",
|
||
TestID: status.TestID,
|
||
Timestamp: time.Now(),
|
||
Progress: status.Progress,
|
||
CurrentPhase: status.CurrentPhase,
|
||
Message: status.Message,
|
||
Data: status,
|
||
}
|
||
r.SendStreamUpdate(status.TestID, update)
|
||
}
|
||
|
||
// sendErrorUpdate 发送错误更新
|
||
func (r *TestRunner) sendErrorUpdate(testID, message string) {
|
||
update := model.StreamUpdate{
|
||
Type: "error",
|
||
TestID: testID,
|
||
Timestamp: time.Now(),
|
||
Message: message,
|
||
}
|
||
r.SendStreamUpdate(testID, update)
|
||
}
|
||
|
||
// sendCompletionUpdate 发送完成更新
|
||
func (r *TestRunner) sendCompletionUpdate(testID string, result *model.TestResult) {
|
||
update := model.StreamUpdate{
|
||
Type: "completion",
|
||
TestID: testID,
|
||
Timestamp: time.Now(),
|
||
Progress: 100,
|
||
Message: "测试完成",
|
||
Data: result,
|
||
}
|
||
r.SendStreamUpdate(testID, update)
|
||
}
|
||
|
||
// sendIntegrityUpdate 发送完整性更新
|
||
func (r *TestRunner) sendIntegrityUpdate(testID string, message string, info *model.IntegrityInfo) {
|
||
update := model.StreamUpdate{
|
||
Type: "integrity",
|
||
TestID: testID,
|
||
Timestamp: time.Now(),
|
||
Message: message,
|
||
Data: info,
|
||
}
|
||
r.SendStreamUpdate(testID, update)
|
||
}
|
||
|
||
// GetTestStatus 获取测试状态
|
||
func (r *TestRunner) GetTestStatus(testID string) *model.TestStatus {
|
||
r.testsMu.RLock()
|
||
defer r.testsMu.RUnlock()
|
||
|
||
if test, ok := r.tests[testID]; ok {
|
||
return test.Status()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetAllTestStatus 获取所有测试状态
|
||
func (r *TestRunner) GetAllTestStatus() []*model.TestStatus {
|
||
r.testsMu.RLock()
|
||
defer r.testsMu.RUnlock()
|
||
|
||
statuses := make([]*model.TestStatus, 0, len(r.tests))
|
||
for _, test := range r.tests {
|
||
statuses = append(statuses, test.Status())
|
||
}
|
||
return statuses
|
||
}
|
||
|
||
// RegisterStream 注册流式连接
|
||
func (r *TestRunner) RegisterStream(testID, clientID string, w http.ResponseWriter) {
|
||
r.streamsMu.Lock()
|
||
defer r.streamsMu.Unlock()
|
||
|
||
if _, ok := r.streams[testID]; !ok {
|
||
r.streams[testID] = make(map[string]http.ResponseWriter)
|
||
}
|
||
r.streams[testID][clientID] = w
|
||
r.logger.Infof("客户端 %s 已连接到测试 %s 的流", clientID, testID)
|
||
}
|
||
|
||
// UnregisterStream 注销流式连接
|
||
func (r *TestRunner) UnregisterStream(testID, clientID string) {
|
||
r.streamsMu.Lock()
|
||
defer r.streamsMu.Unlock()
|
||
|
||
if clients, ok := r.streams[testID]; ok {
|
||
delete(clients, clientID)
|
||
r.logger.Infof("客户端 %s 已断开与测试 %s 的流连接", clientID, testID)
|
||
}
|
||
}
|
||
|
||
// SendStreamUpdate 发送流式更新
|
||
func (r *TestRunner) SendStreamUpdate(testID string, update interface{}) {
|
||
r.streamsMu.RLock()
|
||
defer r.streamsMu.RUnlock()
|
||
|
||
clients, ok := r.streams[testID]
|
||
if !ok || len(clients) == 0 {
|
||
return
|
||
}
|
||
|
||
data, err := json.Marshal(update)
|
||
if err != nil {
|
||
r.logger.Errorf("无法序列化流更新: %v", err)
|
||
return
|
||
}
|
||
|
||
for clientID, w := range clients {
|
||
// 使用Server-Sent Events格式
|
||
_, err := fmt.Fprintf(w, "data: %s\n\n", data)
|
||
if err != nil {
|
||
r.logger.Warnf("向客户端 %s 发送更新失败: %v", clientID, err)
|
||
} else {
|
||
if f, ok := w.(http.Flusher); ok {
|
||
f.Flush()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// SaveIntegrityInfo 保存完整性信息
|
||
func (r *TestRunner) SaveIntegrityInfo(testID string, info *model.IntegrityInfo) {
|
||
r.integrityMu.Lock()
|
||
defer r.integrityMu.Unlock()
|
||
r.integrityInfo[testID] = info
|
||
}
|
||
|
||
// GetIntegrityInfo 获取完整性信息
|
||
func (r *TestRunner) GetIntegrityInfo(testID string) *model.IntegrityInfo {
|
||
r.integrityMu.RLock()
|
||
defer r.integrityMu.RUnlock()
|
||
return r.integrityInfo[testID]
|
||
}
|
||
|
||
// StartServer 启动HTTP服务器
|
||
func StartServer(cfg *config.Config, runner *TestRunner, logger *logrus.Logger) *http.Server {
|
||
mux := http.NewServeMux()
|
||
|
||
// 健康检查接口
|
||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||
health := &model.HealthStatus{
|
||
Status: "ok",
|
||
Timestamp: time.Now(),
|
||
Message: "服务正常运行",
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(health)
|
||
})
|
||
|
||
// 运行测试接口
|
||
mux.HandleFunc("/run", func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req model.TestRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
logger.Infof("收到测试请求: %+v", req)
|
||
|
||
// 异步运行测试
|
||
go func() {
|
||
result, err := runner.RunTest(req.TestType)
|
||
if err != nil {
|
||
logger.Errorf("测试运行失败: %v", err)
|
||
} else {
|
||
logger.Infof("测试完成: %+v", result)
|
||
}
|
||
}()
|
||
|
||
resp := model.TestResponse{
|
||
RequestID: req.TestType + "-" + time.Now().Format("20060102-150405"),
|
||
Status: "accepted",
|
||
Message: "测试已接受并开始执行",
|
||
ServerTime: time.Now(),
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusAccepted)
|
||
json.NewEncoder(w).Encode(resp)
|
||
})
|
||
|
||
// 获取测试状态接口
|
||
mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
testID := r.URL.Query().Get("test_id")
|
||
var status interface{}
|
||
|
||
if testID == "" {
|
||
// 获取所有测试状态
|
||
status = runner.GetAllTestStatus()
|
||
} else {
|
||
// 获取指定测试状态
|
||
status = runner.GetTestStatus(testID)
|
||
if status == nil {
|
||
http.Error(w, "Test not found", http.StatusNotFound)
|
||
return
|
||
}
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(status)
|
||
})
|
||
|
||
// 新增: 实时数据进度流式API
|
||
mux.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) {
|
||
testID := r.URL.Query().Get("test_id")
|
||
if testID == "" {
|
||
http.Error(w, "Missing test_id", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 设置响应头,支持SSE (Server-Sent Events)
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
w.Header().Set("Cache-Control", "no-cache")
|
||
w.Header().Set("Connection", "keep-alive")
|
||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||
|
||
// 创建完成通道
|
||
doneCh := make(chan struct{})
|
||
defer close(doneCh)
|
||
|
||
// 注册客户端连接
|
||
clientID := r.URL.Query().Get("client_id")
|
||
runner.RegisterStream(testID, clientID, w)
|
||
defer runner.UnregisterStream(testID, clientID)
|
||
|
||
// 保持连接直到客户端断开
|
||
select {
|
||
case <-r.Context().Done():
|
||
runner.logger.Infof("connection closed by client %s", clientID)
|
||
return
|
||
case <-doneCh:
|
||
runner.logger.Infof("connection closed by server for client %s", clientID)
|
||
return
|
||
}
|
||
})
|
||
|
||
// 新增: 数据完整性检测API
|
||
mux.HandleFunc("/integrity", func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
testID := r.URL.Query().Get("test_id")
|
||
if testID == "" {
|
||
http.Error(w, "Missing test_id", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 获取测试的数据完整性信息
|
||
integrityInfo := runner.GetIntegrityInfo(testID)
|
||
if integrityInfo == nil {
|
||
http.Error(w, "Integrity info not found", http.StatusNotFound)
|
||
return
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
json.NewEncoder(w).Encode(integrityInfo)
|
||
})
|
||
|
||
// 新增: 恢复测试API,用于断电测试后的恢复与校验
|
||
mux.HandleFunc("/recovery", func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
TestType string `json:"test_type"`
|
||
TestDir string `json:"test_dir"`
|
||
}
|
||
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
logger.Infof("收到恢复测试请求: %+v", req)
|
||
|
||
// 创建恢复测试实例
|
||
test, err := runner.factory.CreateTestCase(req.TestType)
|
||
if err != nil || test == nil {
|
||
http.Error(w, fmt.Sprintf("无法创建测试实例: %v", err), http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
// 获取测试ID
|
||
testID := test.Status().TestID
|
||
|
||
// 执行恢复和数据完整性检查
|
||
go func() {
|
||
ctx := context.Background()
|
||
|
||
// 设置测试环境
|
||
logger.Info("设置恢复测试环境")
|
||
if err := test.Setup(ctx, true); err != nil {
|
||
logger.Errorf("设置恢复测试环境失败: %v", err)
|
||
runner.sendErrorUpdate(testID, fmt.Sprintf("设置恢复测试环境失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 数据完整性检查
|
||
logger.Info("执行数据完整性检查")
|
||
runner.sendStatusUpdate(test)
|
||
|
||
// 检查并获取数据完整性信息
|
||
if powerTest, ok := test.(*testcase.PowerLossTest); ok {
|
||
integrityInfo := powerTest.CheckIntegrity()
|
||
|
||
go func() {
|
||
time.Sleep(1 * time.Second)
|
||
runner.sendIntegrityUpdate(testID, "开始数据完整性检查", nil)
|
||
}()
|
||
|
||
// 保存完整性信息
|
||
runner.SaveIntegrityInfo(testID, integrityInfo)
|
||
|
||
// 发送完整性信息
|
||
|
||
runner.sendIntegrityUpdate(testID, "数据完整性检查完成", integrityInfo)
|
||
|
||
logger.Infof("恢复测试完成: 丢失数据: %.2f MB", integrityInfo.DataLossMB)
|
||
} else {
|
||
logger.Error("不是断电测试实例,无法执行数据完整性检查")
|
||
runner.sendErrorUpdate(testID, "不是断电测试实例,无法执行数据完整性检查")
|
||
}
|
||
|
||
// 清理测试环境
|
||
logger.Info("清理恢复测试环境")
|
||
if err := test.Cleanup(ctx); err != nil {
|
||
logger.Errorf("清理恢复测试环境失败: %v", err)
|
||
}
|
||
}()
|
||
|
||
// 返回接受响应
|
||
resp := model.TestResponse{
|
||
RequestID: testID,
|
||
Status: "accepted",
|
||
Message: "恢复测试已接受并开始执行",
|
||
ServerTime: time.Now(),
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusAccepted)
|
||
json.NewEncoder(w).Encode(resp)
|
||
})
|
||
|
||
// 启动服务器
|
||
addr := fmt.Sprintf("%s:%d", cfg.Server.ListenAddr, cfg.Server.Port)
|
||
server := &http.Server{
|
||
Addr: addr,
|
||
Handler: mux,
|
||
}
|
||
|
||
go func() {
|
||
logger.Infof("服务器启动在 %s", addr)
|
||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||
logger.Fatalf("服务器启动失败: %v", err)
|
||
}
|
||
}()
|
||
|
||
return server
|
||
}
|
||
|
||
func main() {
|
||
flag.Parse()
|
||
|
||
// 初始化日志级别
|
||
var level logrus.Level
|
||
switch logLevel {
|
||
case "debug":
|
||
level = logrus.DebugLevel
|
||
case "info":
|
||
level = logrus.InfoLevel
|
||
case "warn":
|
||
level = logrus.WarnLevel
|
||
case "error":
|
||
level = logrus.ErrorLevel
|
||
default:
|
||
level = logrus.InfoLevel
|
||
}
|
||
|
||
// 初始化日志
|
||
logger := logrus.New()
|
||
logger.SetLevel(level)
|
||
logger.SetFormatter(&logrus.TextFormatter{
|
||
FullTimestamp: true,
|
||
TimestampFormat: "2006-01-02 15:04:05",
|
||
})
|
||
|
||
// 加载配置
|
||
logger.Infof("加载配置文件: %s", configFile)
|
||
cfg, err := config.Load(configFile)
|
||
if err != nil {
|
||
logger.Fatalf("加载配置失败: %v", err)
|
||
}
|
||
|
||
// 初始化日志文件
|
||
if cfg.Server.LogFile != "" {
|
||
utils.InitLogger(cfg.Server.LogFile, level)
|
||
logger = utils.Logger
|
||
}
|
||
|
||
// 创建测试运行器
|
||
runner := NewTestRunner(cfg, logger)
|
||
|
||
// 启动服务器
|
||
server := StartServer(cfg, runner, logger)
|
||
|
||
// 等待终止信号
|
||
stop := make(chan os.Signal, 1)
|
||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||
<-stop
|
||
|
||
logger.Info("正在关闭服务器...")
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := server.Shutdown(ctx); err != nil {
|
||
logger.Fatalf("服务器强制关闭: %v", err)
|
||
}
|
||
|
||
logger.Info("服务器已优雅关闭")
|
||
}
|