299 lines
7.2 KiB
Go
299 lines
7.2 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"flag"
|
||
"fmt"
|
||
"net/http"
|
||
"time"
|
||
|
||
"plp-test/internal/config"
|
||
"plp-test/internal/model"
|
||
"plp-test/internal/utils"
|
||
|
||
"github.com/google/uuid"
|
||
"github.com/sirupsen/logrus"
|
||
)
|
||
|
||
// Command-line options. Each one is registered with the flag package in init
// and read after flag.Parse in main.
var (
	// String options: configuration file path, log level, server address
	// override (host:port) and test type.
	configFile, logLevel, serverAddr, testType string

	// Integer options: timeout in seconds, test data size in MB and block
	// size in KB. Zero means the option was not given on the command line.
	timeout, dataSizeMB, blockSize int

	// concurrent selects concurrent execution when running all tests.
	concurrent bool
)
|
||
|
||
// init registers every command-line flag, binding each one to its
// package-level option variable. Registration order is kept stable:
// string flags, then integer flags, then the boolean flag.
func init() {
	stringFlags := []struct {
		target             *string
		name, value, usage string
	}{
		{&configFile, "config", "config.yaml", "配置文件路径"},
		{&logLevel, "log-level", "info", "日志级别 (debug, info, warn, error)"},
		{&serverAddr, "server", "", "服务器地址,格式为 host:port"},
		{&testType, "test", "sequential", "测试类型 (sequential, random, mixed, concurrent, power_loss, stability, all)"},
	}
	for _, f := range stringFlags {
		flag.StringVar(f.target, f.name, f.value, f.usage)
	}

	// Every integer flag defaults to 0, meaning "not set on the command line".
	intFlags := []struct {
		target      *int
		name, usage string
	}{
		{&timeout, "timeout", "测试超时时间(秒)"},
		{&dataSizeMB, "data-size", "测试数据大小(MB)"},
		{&blockSize, "block-size", "数据块大小(KB)"},
	}
	for _, f := range intFlags {
		flag.IntVar(f.target, f.name, 0, f.usage)
	}

	flag.BoolVar(&concurrent, "concurrent", false, "是否并发执行所有测试")
}
|
||
|
||
// Client 客户端
|
||
type Client struct {
|
||
config *config.Config
|
||
logger *logrus.Logger
|
||
httpClient *http.Client
|
||
serverAddr string
|
||
clientID string
|
||
}
|
||
|
||
// NewClient builds a Client for the test server.
//
// A non-empty serverAddr wins; otherwise the address configured in
// cfg.Client.ServerAddr is used. The HTTP client's timeout is
// cfg.Client.TimeoutSec seconds (net/http treats 0 as "no timeout"), and the
// new Client gets its own ID from uuid.New.
func NewClient(cfg *config.Config, logger *logrus.Logger, serverAddr string) *Client {
	addr := cfg.Client.ServerAddr
	if serverAddr != "" {
		addr = serverAddr
	}

	hc := &http.Client{Timeout: time.Second * time.Duration(cfg.Client.TimeoutSec)}

	return &Client{
		config:     cfg,
		logger:     logger,
		httpClient: hc,
		serverAddr: addr,
		clientID:   uuid.New().String(),
	}
}
|
||
|
||
// RunTest submits one test of type testType to the server (POST /run) and
// then blocks in MonitorTestStatus until the test is over, returning that
// method's result.
//
// A dataSizeMB or blockSize of 0 means "use the value from the Test section
// of the configuration". The server must answer with 202 Accepted and a JSON
// model.TestResponse whose RequestID is then polled. Any other status code is
// an error. The error message includes the beginning of the response body so
// the server's explanation, if any, is not lost. Underlying errors are
// wrapped with %w so callers can inspect them with errors.Is/As.
func (c *Client) RunTest(testType string, dataSizeMB, blockSize int) error {
	c.logger.Infof("运行测试 %s", testType)

	req := model.TestRequest{
		TestType:    testType,
		DataSizeMB:  dataSizeMB,
		BlockSize:   blockSize,
		Concurrency: c.config.Client.Concurrency,
		ClientID:    c.clientID,
		RequestTime: time.Now(),
		// Non-nil so it is encoded as {} rather than null.
		Parameters: make(map[string]string),
	}

	// Zero means "not specified": fall back to the configured defaults.
	if req.DataSizeMB == 0 {
		req.DataSizeMB = c.config.Test.DataSizeMB
	}
	if req.BlockSize == 0 {
		req.BlockSize = c.config.Test.BlockSize
	}

	reqData, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("序列化请求数据失败: %w", err)
	}

	endpoint := fmt.Sprintf("http://%s/run", c.serverAddr)
	resp, err := c.httpClient.Post(endpoint, "application/json", bytes.NewReader(reqData))
	if err != nil {
		return fmt.Errorf("发送请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusAccepted {
		// The body may carry the server's reason for rejecting the request;
		// the bare status code alone is hard to act on. This is a best-effort,
		// size-capped read, so its error is intentionally ignored.
		snippet := make([]byte, 512)
		n, _ := resp.Body.Read(snippet)
		return fmt.Errorf("服务器返回错误状态码: %d, 响应: %s", resp.StatusCode, bytes.TrimSpace(snippet[:n]))
	}

	var testResp model.TestResponse
	if err := json.NewDecoder(resp.Body).Decode(&testResp); err != nil {
		return fmt.Errorf("解析响应失败: %w", err)
	}

	c.logger.Infof("测试请求已接受,RequestID: %s", testResp.RequestID)

	return c.MonitorTestStatus(testResp.RequestID)
}
|
||
|
||
// MonitorTestStatus polls GET /status?test_id=<testID> about once per second
// and logs the reported status, progress and phase until the test is over.
//
// It returns:
//   - nil when the server answers 404 Not Found, which this client treats as
//     "the test has already finished";
//   - nil when the reported status is "completed";
//   - an error when the reported status is "failed" or "aborted", so callers
//     (RunTest, RunAllTests, main) see the failure instead of reporting
//     success.
//
// Transport errors, other non-200 status codes and undecodable bodies are
// logged and retried after 2 seconds, with no retry limit.
// NOTE(review): the unbounded retry looks deliberate (the server may be
// temporarily unreachable, e.g. during power-loss tests) — confirm before
// adding a cap.
func (c *Client) MonitorTestStatus(testID string) error {
	c.logger.Infof("监控测试状态: %s", testID)

	// The status URL does not change between polls, so build it once.
	statusURL := fmt.Sprintf("http://%s/status?test_id=%s", c.serverAddr, testID)

	for {
		resp, err := c.httpClient.Get(statusURL)
		if err != nil {
			c.logger.Warnf("获取测试状态失败: %v", err)
			time.Sleep(2 * time.Second)
			continue
		}

		if resp.StatusCode == http.StatusNotFound {
			c.logger.Infof("测试 %s 已完成", testID)
			resp.Body.Close()
			return nil
		}

		if resp.StatusCode != http.StatusOK {
			resp.Body.Close()
			c.logger.Warnf("服务器返回错误状态码: %d", resp.StatusCode)
			time.Sleep(2 * time.Second)
			continue
		}

		var status model.TestStatus
		err = json.NewDecoder(resp.Body).Decode(&status)
		resp.Body.Close()
		if err != nil {
			c.logger.Warnf("解析响应失败: %v", err)
			time.Sleep(2 * time.Second)
			continue
		}

		c.logger.Infof("测试状态: %s, 进度: %.2f%%, 阶段: %s",
			status.Status, status.Progress, status.CurrentPhase)

		switch status.Status {
		case "completed":
			c.logger.Infof("测试 %s %s", testID, status.Status)
			return nil
		case "failed", "aborted":
			// Terminal but unsuccessful: report it to the caller rather than
			// returning nil as if the test had passed.
			return fmt.Errorf("测试 %s 结束状态为 %s", testID, status.Status)
		}

		time.Sleep(time.Second)
	}
}
|
||
|
||
// CheckServerHealth queries GET /health on the server and logs the status and
// message it reports.
//
// It returns an error when any of these happens:
//   - the request fails;
//   - the server answers with a status code other than 200 OK;
//   - the body does not decode into model.HealthStatus.
//
// Underlying errors are wrapped with %w.
//
// NOTE(review): health.Status is only logged, not validated — any 200
// response counts as healthy regardless of the Status value it carries.
func (c *Client) CheckServerHealth() error {
	c.logger.Info("检查服务器健康状态")

	healthURL := fmt.Sprintf("http://%s/health", c.serverAddr)
	resp, err := c.httpClient.Get(healthURL)
	if err != nil {
		return fmt.Errorf("连接服务器失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("服务器返回错误状态码: %d", resp.StatusCode)
	}

	var health model.HealthStatus
	if err := json.NewDecoder(resp.Body).Decode(&health); err != nil {
		return fmt.Errorf("解析响应失败: %w", err)
	}

	c.logger.Infof("服务器状态: %s, 消息: %s", health.Status, health.Message)
	return nil
}
|
||
|
||
// RunAllTests 运行所有测试
|
||
func (c *Client) RunAllTests(concurrent bool) error {
|
||
c.logger.Info("运行所有测试")
|
||
|
||
tests := c.config.Test.EnabledTests
|
||
|
||
if concurrent {
|
||
c.logger.Info("并发执行所有测试")
|
||
errCh := make(chan error, len(tests))
|
||
|
||
for _, test := range tests {
|
||
go func(t string) {
|
||
errCh <- c.RunTest(t, dataSizeMB, blockSize)
|
||
}(test)
|
||
}
|
||
|
||
// 等待所有测试完成
|
||
for range tests {
|
||
if err := <-errCh; err != nil {
|
||
c.logger.Errorf("测试失败: %v", err)
|
||
}
|
||
}
|
||
} else {
|
||
c.logger.Info("顺序执行所有测试")
|
||
for _, test := range tests {
|
||
if err := c.RunTest(test, dataSizeMB, blockSize); err != nil {
|
||
c.logger.Errorf("测试 %s 失败: %v", test, err)
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func main() {
|
||
flag.Parse()
|
||
|
||
// 初始化日志级别
|
||
var level logrus.Level
|
||
switch logLevel {
|
||
case "debug":
|
||
level = logrus.DebugLevel
|
||
case "info":
|
||
level = logrus.InfoLevel
|
||
case "warn":
|
||
level = logrus.WarnLevel
|
||
case "error":
|
||
level = logrus.ErrorLevel
|
||
default:
|
||
level = logrus.InfoLevel
|
||
}
|
||
|
||
// 初始化日志
|
||
logger := logrus.New()
|
||
logger.SetLevel(level)
|
||
logger.SetFormatter(&logrus.TextFormatter{
|
||
FullTimestamp: true,
|
||
TimestampFormat: "2006-01-02 15:04:05",
|
||
})
|
||
|
||
// 加载配置
|
||
logger.Infof("加载配置文件: %s", configFile)
|
||
cfg, err := config.Load(configFile)
|
||
if err != nil {
|
||
logger.Fatalf("加载配置失败: %v", err)
|
||
}
|
||
|
||
// 初始化日志文件
|
||
if cfg.Client.LogFile != "" {
|
||
utils.InitLogger(cfg.Client.LogFile, level)
|
||
logger = utils.Logger
|
||
}
|
||
|
||
// 设置超时时间
|
||
if timeout > 0 {
|
||
cfg.Client.TimeoutSec = timeout
|
||
}
|
||
|
||
// 创建客户端
|
||
client := NewClient(cfg, logger, serverAddr)
|
||
|
||
// 检查服务器健康状态
|
||
if err := client.CheckServerHealth(); err != nil {
|
||
logger.Fatalf("服务器健康检查失败: %v", err)
|
||
}
|
||
|
||
// 运行测试
|
||
if testType == "all" {
|
||
if err := client.RunAllTests(concurrent); err != nil {
|
||
logger.Fatalf("运行所有测试失败: %v", err)
|
||
}
|
||
} else {
|
||
if err := client.RunTest(testType, dataSizeMB, blockSize); err != nil {
|
||
logger.Fatalf("运行测试 %s 失败: %v", testType, err)
|
||
}
|
||
}
|
||
|
||
logger.Info("所有测试已完成")
|
||
}
|