|  |  |  | @@ -20,6 +20,7 @@ import ( | 
		
	
		
			
				|  |  |  |  | 	"fmt" | 
		
	
		
			
				|  |  |  |  | 	"net" | 
		
	
		
			
				|  |  |  |  | 	"os" | 
		
	
		
			
				|  |  |  |  | 	"strings" | 
		
	
		
			
				|  |  |  |  | 	"syscall" | 
		
	
		
			
				|  |  |  |  | 	"time" | 
		
	
		
			
				|  |  |  |  |  | 
		
	
	
		
			
				
					
					|  |  |  | @@ -30,8 +31,8 @@ import ( | 
		
	
		
			
				|  |  |  |  | // CmdConnect is used by agnhost Cobra. | 
		
	
		
			
				|  |  |  |  | var CmdConnect = &cobra.Command{ | 
		
	
		
			
				|  |  |  |  | 	Use:   "connect [host:port]", | 
		
	
		
			
				|  |  |  |  | 	Short: "Attempts a TCP or SCTP connection and returns useful errors", | 
		
	
		
			
				|  |  |  |  | 	Long: `Tries to open a TCP or SCTP connection to the given host and port. On error it prints an error message prefixed with a specific fixed string that test cases can check for: | 
		
	
		
			
				|  |  |  |  | 	Short: "Attempts a TCP, UDP or SCTP connection and returns useful errors", | 
		
	
		
			
				|  |  |  |  | 	Long: `Tries to open a TCP, UDP or SCTP connection to the given host and port. On error it prints an error message prefixed with a specific fixed string that test cases can check for: | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | * UNKNOWN - Generic/unknown (non-network) error (eg, bad arguments) | 
		
	
		
			
				|  |  |  |  | * TIMEOUT - The connection attempt timed out | 
		
	
	
		
			
				
					
					|  |  |  | @@ -42,12 +43,16 @@ var CmdConnect = &cobra.Command{ | 
		
	
		
			
				|  |  |  |  | 	Run:  main, | 
		
	
		
			
				|  |  |  |  | } | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | var timeout time.Duration | 
		
	
		
			
				|  |  |  |  | var protocol string | 
		
	
		
			
				|  |  |  |  | var ( | 
		
	
		
			
				|  |  |  |  | 	timeout  time.Duration | 
		
	
		
			
				|  |  |  |  | 	protocol string | 
		
	
		
			
				|  |  |  |  | 	udpData  string | 
		
	
		
			
				|  |  |  |  | ) | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | func init() { | 
		
	
		
			
				|  |  |  |  | 	CmdConnect.Flags().DurationVar(&timeout, "timeout", time.Duration(0), "Maximum time before returning an error") | 
		
	
		
			
				|  |  |  |  | 	CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp or sctp") | 
		
	
		
			
				|  |  |  |  | 	CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp, udp or sctp") | 
		
	
		
			
				|  |  |  |  | 	CmdConnect.Flags().StringVar(&udpData, "udp-data", "hostname", "The UDP payload send to the server") | 
		
	
		
			
				|  |  |  |  | } | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | func main(cmd *cobra.Command, args []string) { | 
		
	
	
		
			
				
					
					|  |  |  | @@ -55,6 +60,8 @@ func main(cmd *cobra.Command, args []string) { | 
		
	
		
			
				|  |  |  |  | 	switch protocol { | 
		
	
		
			
				|  |  |  |  | 	case "", "tcp": | 
		
	
		
			
				|  |  |  |  | 		connectTCP(dest, timeout) | 
		
	
		
			
				|  |  |  |  | 	case "udp": | 
		
	
		
			
				|  |  |  |  | 		connectUDP(dest, timeout, udpData) | 
		
	
		
			
				|  |  |  |  | 	case "sctp": | 
		
	
		
			
				|  |  |  |  | 		connectSCTP(dest, timeout) | 
		
	
		
			
				|  |  |  |  | 	default: | 
		
	
	
		
			
				
					
					|  |  |  | @@ -125,3 +132,54 @@ func connectSCTP(dest string, timeout time.Duration) { | 
		
	
		
			
				|  |  |  |  | 		os.Exit(1) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | } | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | func connectUDP(dest string, timeout time.Duration, data string) { | 
		
	
		
			
				|  |  |  |  | 	var ( | 
		
	
		
			
				|  |  |  |  | 		readBytes int | 
		
	
		
			
				|  |  |  |  | 		buf       = make([]byte, 1024) | 
		
	
		
			
				|  |  |  |  | 	) | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	if _, err := net.ResolveUDPAddr("udp", dest); err != nil { | 
		
	
		
			
				|  |  |  |  | 		fmt.Fprintf(os.Stderr, "DNS: %v\n", err) | 
		
	
		
			
				|  |  |  |  | 		os.Exit(1) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	conn, err := net.Dial("udp", dest) | 
		
	
		
			
				|  |  |  |  | 	if err != nil { | 
		
	
		
			
				|  |  |  |  | 		fmt.Fprintf(os.Stderr, "OTHER: %v\n", err) | 
		
	
		
			
				|  |  |  |  | 		os.Exit(1) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	if timeout > 0 { | 
		
	
		
			
				|  |  |  |  | 		if err = conn.SetDeadline(time.Now().Add(timeout)); err != nil { | 
		
	
		
			
				|  |  |  |  | 			fmt.Fprintf(os.Stderr, "OTHER: %v\n", err) | 
		
	
		
			
				|  |  |  |  | 			os.Exit(1) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil { | 
		
	
		
			
				|  |  |  |  | 		parseUDPErrorAndExit(err) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	if readBytes, err = conn.Read(buf); err != nil { | 
		
	
		
			
				|  |  |  |  | 		parseUDPErrorAndExit(err) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	// ensure the response from UDP server | 
		
	
		
			
				|  |  |  |  | 	if readBytes == 0 { | 
		
	
		
			
				|  |  |  |  | 		fmt.Fprintf(os.Stderr, "OTHER: No data received from the server. Cannot guarantee the server received the request.\n") | 
		
	
		
			
				|  |  |  |  | 		os.Exit(1) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | } | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | func parseUDPErrorAndExit(err error) { | 
		
	
		
			
				|  |  |  |  | 	neterr, ok := err.(net.Error) | 
		
	
		
			
				|  |  |  |  | 	if ok && neterr.Timeout() { | 
		
	
		
			
				|  |  |  |  | 		fmt.Fprintf(os.Stderr, "TIMEOUT: %v\n", err) | 
		
	
		
			
				|  |  |  |  | 	} else if strings.Contains(err.Error(), "connection refused") { | 
		
	
		
			
				|  |  |  |  | 		fmt.Fprintf(os.Stderr, "REFUSED: %v\n", err) | 
		
	
		
			
				|  |  |  |  | 	} else { | 
		
	
		
			
				|  |  |  |  | 		fmt.Fprintf(os.Stderr, "UNKNOWN: %v\n", err) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	os.Exit(1) | 
		
	
		
			
				|  |  |  |  | } | 
		
	
	
		
			
				
					
					| 
							
							
							
						 |  |  |   |