132 lines
4.5 KiB
Go
132 lines
4.5 KiB
Go
// Copyright (c) 2017 VMware, Inc. All Rights Reserved.
|
|
//
|
|
// This product is licensed to you under the Apache License, Version 2.0 (the "License").
|
|
// You may not use this product except in compliance with the License.
|
|
//
|
|
// This product may include a number of subcomponents with separate copyright notices and
|
|
// license terms. Your use of these subcomponents is subject to the terms and conditions
|
|
// of the subcomponent's license, as noted in the LICENSE file.
|
|
|
|
// +build windows
|
|
|
|
package lightwave
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"fmt"
|
|
"github.com/vmware/photon-controller-go-sdk/SSPI"
|
|
"math/rand"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
const gssTicketGrantFormatString = "grant_type=urn:vmware:grant_type:gss_ticket&gss_ticket=%s&context_id=%s&scope=%s"
|
|
|
|
// GetTokensFromWindowsLogInContext gets tokens based on Windows logged in context
|
|
// Here is how it works:
|
|
// 1. Get the SPN (Service Principal Name) in the format host/FQDN of lightwave. This is needed for SSPI/Kerberos protocol
|
|
// 2. Call Windows API AcquireCredentialsHandle() using SSPI library. This will give the current users credential handle
|
|
// 3. Using this handle call Windows API AcquireCredentialsHandle(). This will give you byte[]
|
|
// 4. Encode this byte[] and send it to OIDC server over HTTP (using POST)
|
|
// 5. OIDC server can send either of the following
|
|
// - Access tokens. In this case return access tokens to client
|
|
// - Error in the format: invalid_grant: gss_continue_needed:'context id':'token from server'
|
|
// 6. In case you get error, parse it and get the token from server
|
|
// 7. Feed this token to step 3 and repeat steps till you get the access tokens from server
|
|
func (client *OIDCClient) GetTokensFromWindowsLogInContext() (tokens *OIDCTokenResponse, err error) {
|
|
spn, err := client.buildSPN()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
auth, _ := SSPI.GetAuth("", "", spn, "")
|
|
|
|
userContext, err := auth.InitialBytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// In case of multiple req/res between client and server (as explained in above comment),
|
|
// server needs to maintain the mapping of context id -> token
|
|
// So we need to generate random string as a context id
|
|
// If we use same context id for all the requests, results can be erroneous
|
|
contextId := client.generateRandomString()
|
|
body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
|
|
tokens, err = client.getToken(body)
|
|
|
|
for {
|
|
if err == nil {
|
|
break
|
|
}
|
|
|
|
// In case of error the response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
|
|
gssToken := client.validateAndExtractGSSResponse(err, contextId)
|
|
if gssToken == "" {
|
|
return nil, err
|
|
}
|
|
|
|
data, err := base64.StdEncoding.DecodeString(gssToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userContext, err := auth.NextBytes(data)
|
|
body := fmt.Sprintf(gssTicketGrantFormatString, url.QueryEscape(base64.StdEncoding.EncodeToString(userContext)), contextId, client.Options.TokenScope)
|
|
tokens, err = client.getToken(body)
|
|
}
|
|
|
|
return tokens, err
|
|
}
|
|
|
|
// Gets the SPN (Service Principal Name) in the format host/FQDN of lightwave
|
|
func (client *OIDCClient) buildSPN() (spn string, err error) {
|
|
u, err := url.Parse(client.Endpoint)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
host, _, err := net.SplitHostPort(u.Host)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
addr, err := net.LookupAddr(host)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var s = strings.TrimSuffix(addr[0], ".")
|
|
return "host/" + s, nil
|
|
}
|
|
|
|
// validateAndExtractGSSResponse parse the error from server and returns token from server
|
|
// In case of error from the server, response will be in format: invalid_grant: gss_continue_needed:'context id':'token from server'
|
|
// So, we check for the above format in error and then return the token from server
|
|
// If error is not in above format, we return empty string
|
|
func (client *OIDCClient) validateAndExtractGSSResponse(err error, contextId string) string {
|
|
parts := strings.Split(err.Error(), ":")
|
|
if !(len(parts) == 4 && strings.TrimSpace(parts[1]) == "gss_continue_needed" && parts[2] == contextId) {
|
|
return ""
|
|
} else {
|
|
return parts[3]
|
|
}
|
|
}
|
|
|
|
func (client *OIDCClient) generateRandomString() string {
|
|
const length = 10
|
|
const asciiA = 65
|
|
const asciiZ = 90
|
|
rand.Seed(time.Now().UTC().UnixNano())
|
|
bytes := make([]byte, length)
|
|
for i := 0; i < length; i++ {
|
|
bytes[i] = byte(randInt(asciiA, asciiZ))
|
|
}
|
|
return string(bytes)
|
|
}
|
|
|
|
func randInt(min int, max int) int {
|
|
return min + rand.Intn(max-min)
|
|
}
|