func TestHTTP_Fallback_Disabled(t *testing.T) { handler1 := http.NewServeMux() handler2 := http.NewServeMux() handler3 := http.NewServeMux() coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, }, ClusterAddr: "empty", } // Chicken-and-egg: Handler needs a core. So we create handlers first, then // add routes chained to a Handler-created handler. cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) for _, core := range cores { defer core.CloseListeners() } handler1.Handle("/", Handler(cores[0].Core)) handler2.Handle("/", Handler(cores[1].Core)) handler3.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) root := cores[0].Root addrs := []string{ fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port), } for _, addr := range addrs { config := api.DefaultConfig() config.Address = addr config.HttpClient = cleanhttp.DefaultClient() config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig client, err := api.NewClient(config) if err != nil { t.Fatal(err) } client.SetToken(root) secret, err := client.Auth().Token().LookupSelf() if err != nil { t.Fatal(err) } if secret == nil { t.Fatal("secret is nil") } if secret.Data["id"].(string) != root { t.Fatal("token mismatch") } } }
// This function recreates the fuzzy testing from transit to pipe a large // number of requests from the standbys to the active node. func TestHTTP_Forwarding_Stress(t *testing.T) { testPlaintext := "the quick brown fox" testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA==" handler1 := http.NewServeMux() handler2 := http.NewServeMux() handler3 := http.NewServeMux() coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "transit": transit.Factory, }, } // Chicken-and-egg: Handler needs a core. So we create handlers first, then // add routes chained to a Handler-created handler. cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) for _, core := range cores { defer core.CloseListeners() } handler1.Handle("/", Handler(cores[0].Core)) handler2.Handle("/", Handler(cores[1].Core)) handler3.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) root := cores[0].Root wg := sync.WaitGroup{} funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"} keys := []string{"test1", "test2", "test3"} hosts := []string{ fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port), } transport := cleanhttp.DefaultPooledTransport() transport.TLSClientConfig = cores[0].TLSConfig client := &http.Client{ Transport: transport, CheckRedirect: func(*http.Request, []*http.Request) error { return fmt.Errorf("redirects not allowed in this test") }, } //core.Logger().Printf("[TRACE] mounting transit") req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port), bytes.NewBuffer([]byte("{\"type\": \"transit\"}"))) if err != nil { t.Fatal(err) } req.Header.Set(AuthHeaderName, root) _, err = client.Do(req) if err != nil { t.Fatal(err) } //core.Logger().Printf("[TRACE] done mounting transit") var totalOps int64 var successfulOps int64 var key1ver int64 = 1 var key2ver int64 = 1 var key3ver int64 = 1 // This is the goroutine loop doFuzzy := func(id int) { // Check for panics, otherwise notify we're done defer func() { if err := recover(); err != nil { core.Logger().Printf("[ERR] got a panic: %v", err) t.Fail() } wg.Done() }() // Holds the latest encrypted value for each key latestEncryptedText := map[string]string{} startTime := time.Now() client := &http.Client{ Transport: transport, } var chosenFunc, chosenKey, chosenHost string doReq := func(method, url string, body io.Reader) (*http.Response, error) { req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } req.Header.Set(AuthHeaderName, root) resp, err := client.Do(req) if err != nil { return nil, err } return resp, nil } doResp := func(resp *http.Response) (*api.Secret, error) { if resp == nil { return nil, fmt.Errorf("nil response") } defer resp.Body.Close() // Make sure we weren't redirected if resp.StatusCode > 300 && resp.StatusCode < 400 { return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp) } result := &api.Response{Response: resp} err = result.Error() if err != nil { return nil, err } secret, err := api.ParseSecret(result.Body) if err != nil { return nil, err } return secret, nil } for _, chosenHost := range hosts { for _, chosenKey := range keys { // Try to write the key to make sure it exists _, err := doReq("POST", chosenHost+"keys/"+chosenKey, bytes.NewBuffer([]byte("{}"))) if err != nil { panic(err) } } } //core.Logger().Printf("[TRACE] Starting %d", id) for { // Stop after 10 seconds if time.Now().Sub(startTime) > 10*time.Second { return } atomic.AddInt64(&totalOps, 1) // Pick a function and a key chosenFunc = funcs[rand.Int()%len(funcs)] chosenKey = keys[rand.Int()%len(keys)] chosenHost = hosts[rand.Int()%len(hosts)] switch chosenFunc { // Encrypt our plaintext and store the result case "encrypt": //core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id) resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64)))) if err != nil { panic(err) } secret, err := doResp(resp) if err != nil { panic(err) } latest := secret.Data["ciphertext"].(string) if latest == "" { panic(fmt.Errorf("bad ciphertext")) } latestEncryptedText[chosenKey] = secret.Data["ciphertext"].(string) atomic.AddInt64(&successfulOps, 1) // Decrypt the ciphertext and compare the result case "decrypt": ct := latestEncryptedText[chosenKey] if ct == "" { atomic.AddInt64(&successfulOps, 1) continue } //core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id) resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct)))) if err != nil { panic(err) } secret, err := doResp(resp) if err != nil { // This could well happen since the min version is jumping around if strings.Contains(err.Error(), transit.ErrTooOld) { atomic.AddInt64(&successfulOps, 1) continue } panic(err) } ptb64 := secret.Data["plaintext"].(string) pt, err := base64.StdEncoding.DecodeString(ptb64) if err != nil { panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err)) } if string(pt) != testPlaintext { panic(fmt.Errorf("got bad plaintext back: %s", pt)) } atomic.AddInt64(&successfulOps, 1) // Rotate to a new key version case "rotate": //core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id) _, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBuffer([]byte("{}"))) if err != nil { panic(err) } switch chosenKey { case "test1": atomic.AddInt64(&key1ver, 1) case "test2": atomic.AddInt64(&key2ver, 1) case "test3": atomic.AddInt64(&key3ver, 1) } atomic.AddInt64(&successfulOps, 1) // Change the min version, which also tests the archive functionality case "change_min_version": var latestVersion int64 switch chosenKey { case "test1": latestVersion = atomic.LoadInt64(&key1ver) case "test2": latestVersion = atomic.LoadInt64(&key2ver) case "test3": latestVersion = atomic.LoadInt64(&key3ver) } setVersion := (rand.Int63() % latestVersion) + 1 //core.Logger().Printf("[TRACE] %s, %s, %d, new min version %d", chosenFunc, chosenKey, id, setVersion) _, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBuffer([]byte(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion)))) if err != nil { panic(err) } atomic.AddInt64(&successfulOps, 1) } } } // Spawn 20 of these workers for 10 seconds for i := 0; i < 20; i++ { wg.Add(1) //core.Logger().Printf("[TRACE] spawning %d", i) go doFuzzy(i) } // Wait for them all to finish wg.Wait() core.Logger().Printf("[TRACE] total operations tried: %d, total successful: %d", totalOps, successfulOps) if totalOps != successfulOps { t.Fatalf("total/successful ops mismatch: %d/%d", totalOps, successfulOps) } }
// This tests TLS connection state forwarding by ensuring that we can use a // client TLS to authenticate against the cert backend func TestHTTP_Forwarding_ClientTLS(t *testing.T) { handler1 := http.NewServeMux() handler2 := http.NewServeMux() handler3 := http.NewServeMux() coreConfig := &vault.CoreConfig{ CredentialBackends: map[string]logical.Factory{ "cert": credCert.Factory, }, } // Chicken-and-egg: Handler needs a core. So we create handlers first, then // add routes chained to a Handler-created handler. cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) for _, core := range cores { defer core.CloseListeners() } handler1.Handle("/", Handler(cores[0].Core)) handler2.Handle("/", Handler(cores[1].Core)) handler3.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) root := cores[0].Root transport := cleanhttp.DefaultTransport() transport.TLSClientConfig = cores[0].TLSConfig client := &http.Client{ Transport: transport, } req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/auth/cert", cores[0].Listeners[0].Address.Port), bytes.NewBuffer([]byte("{\"type\": \"cert\"}"))) if err != nil { t.Fatal(err) } req.Header.Set(AuthHeaderName, root) _, err = client.Do(req) if err != nil { t.Fatal(err) } type certConfig struct { Certificate string `json:"certificate"` Policies string `json:"policies"` } encodedCertConfig, err := json.Marshal(&certConfig{ Certificate: vault.TestClusterCACert, Policies: "default", }) if err != nil { t.Fatal(err) } req, err = http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/auth/cert/certs/test", cores[0].Listeners[0].Address.Port), bytes.NewBuffer(encodedCertConfig)) if err != nil { t.Fatal(err) } req.Header.Set(AuthHeaderName, root) _, err = client.Do(req) if err != nil { t.Fatal(err) } addrs := []string{ fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port), fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port), } // Ensure we can't possibly use lingering connections even though it should be to a different address transport = cleanhttp.DefaultTransport() transport.TLSClientConfig = cores[0].TLSConfig client = &http.Client{ Transport: transport, CheckRedirect: func(*http.Request, []*http.Request) error { return fmt.Errorf("redirects not allowed in this test") }, } //cores[0].Logger().Printf("root token is %s", root) //time.Sleep(4 * time.Hour) for _, addr := range addrs { config := api.DefaultConfig() config.Address = addr config.HttpClient = client client, err := api.NewClient(config) if err != nil { t.Fatal(err) } secret, err := client.Logical().Write("auth/cert/login", nil) if err != nil { t.Fatal(err) } if secret == nil { t.Fatal("secret is nil") } if secret.Auth == nil { t.Fatal("auth is nil") } if secret.Auth.Policies == nil || len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" { t.Fatalf("bad policies: %#v", secret.Auth.Policies) } } }
// Test wrapping functionality func TestHTTP_Wrapping(t *testing.T) { handler1 := http.NewServeMux() handler2 := http.NewServeMux() handler3 := http.NewServeMux() coreConfig := &vault.CoreConfig{} // Chicken-and-egg: Handler needs a core. So we create handlers first, then // add routes chained to a Handler-created handler. cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true) for _, core := range cores { defer core.CloseListeners() } handler1.Handle("/", Handler(cores[0].Core)) handler2.Handle("/", Handler(cores[1].Core)) handler3.Handle("/", Handler(cores[2].Core)) // make it easy to get access to the active core := cores[0].Core vault.TestWaitActive(t, core) root := cores[0].Root transport := cleanhttp.DefaultTransport() transport.TLSClientConfig = cores[0].TLSConfig httpClient := &http.Client{ Transport: transport, } addr := fmt.Sprintf("https://127.0.0.1:%d", cores[0].Listeners[0].Address.Port) config := api.DefaultConfig() config.Address = addr config.HttpClient = httpClient client, err := api.NewClient(config) if err != nil { t.Fatal(err) } client.SetToken(root) // Write a value that we will use with wrapping for lookup _, err = client.Logical().Write("secret/foo", map[string]interface{}{ "zip": "zap", }) if err != nil { t.Fatal(err) } // Set a wrapping lookup function for reads on that path client.SetWrappingLookupFunc(func(operation, path string) string { if operation == "GET" && path == "secret/foo" { return "5m" } return api.DefaultWrappingLookupFunc(operation, path) }) // First test: basic things that should fail, lookup edition // Root token isn't a wrapping token _, err = client.Logical().Write("sys/wrapping/lookup", nil) if err == nil { t.Fatal("expected error") } // Not supplied _, err = client.Logical().Write("sys/wrapping/lookup", map[string]interface{}{ "foo": "bar", }) if err == nil { t.Fatal("expected error") } // Nonexistent token isn't a wrapping token _, err = client.Logical().Write("sys/wrapping/lookup", map[string]interface{}{ "token": "bar", }) if err == nil { t.Fatal("expected error") } // Second: basic things that should fail, unwrap edition // Root token isn't a wrapping token _, err = client.Logical().Unwrap(root) if err == nil { t.Fatal("expected error") } // Root token isn't a wrapping token _, err = client.Logical().Write("sys/wrapping/unwrap", nil) if err == nil { t.Fatal("expected error") } // Not supplied _, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{ "foo": "bar", }) if err == nil { t.Fatal("expected error") } // Nonexistent token isn't a wrapping token _, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{ "token": "bar", }) if err == nil { t.Fatal("expected error") } // // Test lookup // // Create a wrapping token secret, err := client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if secret == nil || secret.WrapInfo == nil { t.Fatal("secret or wrap info is nil") } wrapInfo := secret.WrapInfo // Test this twice to ensure no ill effect to the wrapping token as a result of the lookup for i := 0; i < 2; i++ { secret, err = client.Logical().Write("sys/wrapping/lookup", map[string]interface{}{ "token": wrapInfo.Token, }) if secret == nil || secret.Data == nil { t.Fatal("secret or secret data is nil") } creationTTL, _ := secret.Data["creation_ttl"].(json.Number).Int64() if int(creationTTL) != wrapInfo.TTL { t.Fatalf("mistmatched ttls: %d vs %d", creationTTL, wrapInfo.TTL) } if secret.Data["creation_time"].(string) != wrapInfo.CreationTime.Format(time.RFC3339Nano) { t.Fatalf("mistmatched creation times: %d vs %d", secret.Data["creation_time"].(string), wrapInfo.CreationTime.Format(time.RFC3339Nano)) } } // // Test unwrap // // Create a wrapping token secret, err = client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if secret == nil || secret.WrapInfo == nil { t.Fatal("secret or wrap info is nil") } wrapInfo = secret.WrapInfo // Test unwrap via the client token client.SetToken(wrapInfo.Token) secret, err = client.Logical().Write("sys/wrapping/unwrap", nil) if secret == nil || secret.Data == nil { t.Fatal("secret or secret data is nil") } ret1 := secret // Should be expired and fail _, err = client.Logical().Write("sys/wrapping/unwrap", nil) if err == nil { t.Fatal("expected err") } // Create a wrapping token client.SetToken(root) secret, err = client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if secret == nil || secret.WrapInfo == nil { t.Fatal("secret or wrap info is nil") } wrapInfo = secret.WrapInfo // Test as a separate token secret, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{ "token": wrapInfo.Token, }) ret2 := secret // Should be expired and fail _, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{ "token": wrapInfo.Token, }) if err == nil { t.Fatal("expected err") } // Create a wrapping token secret, err = client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if secret == nil || secret.WrapInfo == nil { t.Fatal("secret or wrap info is nil") } wrapInfo = secret.WrapInfo // Read response directly client.SetToken(wrapInfo.Token) secret, err = client.Logical().Read("cubbyhole/response") ret3 := secret // Should be expired and fail _, err = client.Logical().Write("cubbyhole/response", nil) if err == nil { t.Fatal("expected err") } // Create a wrapping token client.SetToken(root) secret, err = client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if secret == nil || secret.WrapInfo == nil { t.Fatal("secret or wrap info is nil") } wrapInfo = secret.WrapInfo // Read via Unwrap method secret, err = client.Logical().Unwrap(wrapInfo.Token) ret4 := secret // Should be expired and fail _, err = client.Logical().Unwrap(wrapInfo.Token) if err == nil { t.Fatal("expected err") } if !reflect.DeepEqual(ret1.Data, map[string]interface{}{ "zip": "zap", }) { t.Fatalf("ret1 data did not match expected: %#v", ret1.Data) } if !reflect.DeepEqual(ret2.Data, map[string]interface{}{ "zip": "zap", }) { t.Fatalf("ret2 data did not match expected: %#v", ret2.Data) } var ret3Secret api.Secret err = jsonutil.DecodeJSON([]byte(ret3.Data["response"].(string)), &ret3Secret) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(ret3Secret.Data, map[string]interface{}{ "zip": "zap", }) { t.Fatalf("ret3 data did not match expected: %#v", ret3Secret.Data) } if !reflect.DeepEqual(ret4.Data, map[string]interface{}{ "zip": "zap", }) { t.Fatalf("ret4 data did not match expected: %#v", ret4.Data) } // // Custom wrapping // client.SetToken(root) data := map[string]interface{}{ "zip": "zap", "three": json.Number("2"), } // Don't set a request TTL on that path, should fail client.SetWrappingLookupFunc(func(operation, path string) string { return "" }) secret, err = client.Logical().Write("sys/wrapping/wrap", data) if err == nil { t.Fatal("expected error") } // Re-set the lookup function client.SetWrappingLookupFunc(func(operation, path string) string { if operation == "GET" && path == "secret/foo" { return "5m" } return api.DefaultWrappingLookupFunc(operation, path) }) secret, err = client.Logical().Write("sys/wrapping/wrap", data) if err != nil { t.Fatal(err) } secret, err = client.Logical().Unwrap(secret.WrapInfo.Token) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(data, secret.Data) { t.Fatal("custom wrap did not match expected: %#v", secret.Data) } // // Test rewrap // // Create a wrapping token secret, err = client.Logical().Read("secret/foo") if err != nil { t.Fatal(err) } if secret == nil || secret.WrapInfo == nil { t.Fatal("secret or wrap info is nil") } wrapInfo = secret.WrapInfo // Test rewrapping secret, err = client.Logical().Write("sys/wrapping/rewrap", map[string]interface{}{ "token": wrapInfo.Token, }) // Should be expired and fail _, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{ "token": wrapInfo.Token, }) if err == nil { t.Fatal("expected err") } // Attempt unwrapping the rewrapped token wrapToken := secret.WrapInfo.Token secret, err = client.Logical().Unwrap(wrapToken) if err != nil { t.Fatal(err) } // Should be expired and fail _, err = client.Logical().Unwrap(wrapToken) if err == nil { t.Fatal("expected err") } if !reflect.DeepEqual(secret.Data, map[string]interface{}{ "zip": "zap", }) { t.Fatalf("secret data did not match expected: %#v", secret.Data) } }