Example #1
0
// TestHTTP_Fallback_Disabled starts a three-node cluster whose ClusterAddr is
// set to "empty" and checks that a root-token self lookup sent to either of
// the two standby nodes still reports the root token.
func TestHTTP_Fallback_Disabled(t *testing.T) {
	muxes := []*http.ServeMux{http.NewServeMux(), http.NewServeMux(), http.NewServeMux()}

	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
		ClusterAddr: "empty",
	}

	// The Vault handlers need a core, so the cluster is started on empty muxes
	// and each mux is given its own core's handler afterwards.
	cores := vault.TestCluster(t, []http.Handler{muxes[0], muxes[1], muxes[2]}, coreConfig, true)
	for _, c := range cores {
		defer c.CloseListeners()
	}
	for i, mux := range muxes {
		mux.Handle("/", Handler(cores[i].Core))
	}

	// Wait for node 0 to become the active node.
	vault.TestWaitActive(t, cores[0].Core)
	rootToken := cores[0].Root

	// lookupSelfVia performs a token self-lookup against the node at addr,
	// using the cluster's TLS config, and fails the test unless the lookup
	// reports the root token.
	lookupSelfVia := func(addr string) {
		cfg := api.DefaultConfig()
		cfg.Address = addr
		cfg.HttpClient = cleanhttp.DefaultClient()
		cfg.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig

		apiClient, err := api.NewClient(cfg)
		if err != nil {
			t.Fatal(err)
		}
		apiClient.SetToken(rootToken)

		self, err := apiClient.Auth().Token().LookupSelf()
		switch {
		case err != nil:
			t.Fatal(err)
		case self == nil:
			t.Fatal("secret is nil")
		case self.Data["id"].(string) != rootToken:
			t.Fatal("token mismatch")
		}
	}

	// Both standby addresses are resolved up front, then queried in order.
	standbyAddrs := make([]string, 0, 2)
	for _, idx := range []int{1, 2} {
		standbyAddrs = append(standbyAddrs, fmt.Sprintf("https://127.0.0.1:%d", cores[idx].Listeners[0].Address.Port))
	}
	for _, standbyAddr := range standbyAddrs {
		lookupSelfVia(standbyAddr)
	}
}
Example #2
0
// TestHTTP_Forwarding_Stress recreates the fuzzy testing from the transit
// backend, piping a large number of concurrent requests through the two
// standby nodes so that they are forwarded to the active node.
//
// Twenty workers each run for ten seconds. Each one repeatedly picks one of
// encrypt, decrypt, rotate or change_min_version, one of three keys, and one
// of the two standby addresses. Any unexpected error (including a redirect,
// which the shared client refuses) panics inside the worker. The panic is
// recovered and turned into a test failure.
func TestHTTP_Forwarding_Stress(t *testing.T) {
	testPlaintext := "the quick brown fox"
	testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA=="

	handler1 := http.NewServeMux()
	handler2 := http.NewServeMux()
	handler3 := http.NewServeMux()

	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
	}

	// Chicken-and-egg: Handler needs a core. So we create handlers first, then
	// add routes chained to a Handler-created handler.
	cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
	for _, core := range cores {
		defer core.CloseListeners()
	}
	handler1.Handle("/", Handler(cores[0].Core))
	handler2.Handle("/", Handler(cores[1].Core))
	handler3.Handle("/", Handler(cores[2].Core))

	// make it easy to get access to the active
	core := cores[0].Core
	vault.TestWaitActive(t, core)

	root := cores[0].Root

	wg := sync.WaitGroup{}

	funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
	keys := []string{"test1", "test2", "test3"}

	hosts := []string{
		fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port),
		fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port),
	}

	transport := cleanhttp.DefaultPooledTransport()
	transport.TLSClientConfig = cores[0].TLSConfig

	// One client is shared by the setup request and every worker
	// (http.Client is safe for concurrent use). Refusing redirects means a
	// standby that redirects instead of forwarding makes Do return an error,
	// rather than the redirect being followed silently.
	client := &http.Client{
		Transport: transport,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return fmt.Errorf("redirects not allowed in this test")
		},
	}

	// closeAndCheck is for requests whose response payload is not otherwise
	// inspected. It closes resp's body so the connection is not leaked, and
	// returns the error, if any, that api.Response.Error reports for resp's
	// status.
	closeAndCheck := func(resp *http.Response) error {
		defer resp.Body.Close()
		return (&api.Response{Response: resp}).Error()
	}

	req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port),
		bytes.NewBuffer([]byte("{\"type\": \"transit\"}")))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set(AuthHeaderName, root)
	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if err := closeAndCheck(resp); err != nil {
		t.Fatal(err)
	}

	var totalOps int64
	var successfulOps int64
	var key1ver int64 = 1
	var key2ver int64 = 1
	var key3ver int64 = 1

	// doFuzzy is the body of one worker goroutine. All state shared between
	// workers is either read-only or updated through sync/atomic. Every
	// error variable below is local to this closure.
	doFuzzy := func(id int) {
		// Check for panics, otherwise notify we're done
		defer func() {
			if err := recover(); err != nil {
				core.Logger().Printf("[ERR] worker %d got a panic: %v", id, err)
				t.Fail()
			}
			wg.Done()
		}()

		// Holds the latest encrypted value for each key
		latestEncryptedText := map[string]string{}

		startTime := time.Now()

		// doReq sends an authenticated request with the root token.
		doReq := func(method, url string, body io.Reader) (*http.Response, error) {
			req, err := http.NewRequest(method, url, body)
			if err != nil {
				return nil, err
			}
			req.Header.Set(AuthHeaderName, root)
			resp, err := client.Do(req)
			if err != nil {
				return nil, err
			}
			return resp, nil
		}

		// doResp closes resp's body and decodes it as an API secret. It
		// rejects redirects and error statuses.
		doResp := func(resp *http.Response) (*api.Secret, error) {
			if resp == nil {
				return nil, fmt.Errorf("nil response")
			}
			defer resp.Body.Close()

			// Make sure we weren't redirected
			if resp.StatusCode >= 300 && resp.StatusCode < 400 {
				return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp)
			}

			// err must be declared here: assigning to the test function's
			// err would be a data race between the concurrent workers.
			result := &api.Response{Response: resp}
			if err := result.Error(); err != nil {
				return nil, err
			}

			secret, err := api.ParseSecret(result.Body)
			if err != nil {
				return nil, err
			}

			return secret, nil
		}

		// Write every key through every standby to make sure it exists
		for _, host := range hosts {
			for _, key := range keys {
				resp, err := doReq("POST", host+"keys/"+key, bytes.NewBuffer([]byte("{}")))
				if err != nil {
					panic(err)
				}
				if err := closeAndCheck(resp); err != nil {
					panic(err)
				}
			}
		}

		// Stop after 10 seconds
		for time.Since(startTime) <= 10*time.Second {
			atomic.AddInt64(&totalOps, 1)

			// Pick a function, a key and a standby to send it to
			chosenFunc := funcs[rand.Int()%len(funcs)]
			chosenKey := keys[rand.Int()%len(keys)]
			chosenHost := hosts[rand.Int()%len(hosts)]

			switch chosenFunc {
			// Encrypt our plaintext and store the result
			case "encrypt":
				resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64))))
				if err != nil {
					panic(err)
				}

				secret, err := doResp(resp)
				if err != nil {
					panic(err)
				}

				latest := secret.Data["ciphertext"].(string)
				if latest == "" {
					panic(fmt.Errorf("bad ciphertext"))
				}
				latestEncryptedText[chosenKey] = latest

				atomic.AddInt64(&successfulOps, 1)

			// Decrypt the ciphertext and compare the result
			case "decrypt":
				ct := latestEncryptedText[chosenKey]
				if ct == "" {
					// Nothing encrypted with this key by this worker yet
					atomic.AddInt64(&successfulOps, 1)
					continue
				}

				resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct))))
				if err != nil {
					panic(err)
				}

				secret, err := doResp(resp)
				if err != nil {
					// This could well happen since the min version is jumping around
					if strings.Contains(err.Error(), transit.ErrTooOld) {
						atomic.AddInt64(&successfulOps, 1)
						continue
					}
					panic(err)
				}

				ptb64 := secret.Data["plaintext"].(string)
				pt, err := base64.StdEncoding.DecodeString(ptb64)
				if err != nil {
					panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err))
				}
				if string(pt) != testPlaintext {
					panic(fmt.Errorf("got bad plaintext back: %s", pt))
				}

				atomic.AddInt64(&successfulOps, 1)

			// Rotate to a new key version
			case "rotate":
				resp, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBuffer([]byte("{}")))
				if err != nil {
					panic(err)
				}
				if err = closeAndCheck(resp); err != nil {
					panic(err)
				}
				// Only count the new version once the rotation has
				// succeeded, so change_min_version never exceeds it.
				switch chosenKey {
				case "test1":
					atomic.AddInt64(&key1ver, 1)
				case "test2":
					atomic.AddInt64(&key2ver, 1)
				case "test3":
					atomic.AddInt64(&key3ver, 1)
				}
				atomic.AddInt64(&successfulOps, 1)

			// Change the min version, which also tests the archive functionality
			case "change_min_version":
				var latestVersion int64
				switch chosenKey {
				case "test1":
					latestVersion = atomic.LoadInt64(&key1ver)
				case "test2":
					latestVersion = atomic.LoadInt64(&key2ver)
				case "test3":
					latestVersion = atomic.LoadInt64(&key3ver)
				}

				// Any version in [1, latestVersion]
				setVersion := (rand.Int63() % latestVersion) + 1

				resp, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBuffer([]byte(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion))))
				if err != nil {
					panic(err)
				}
				if err = closeAndCheck(resp); err != nil {
					panic(err)
				}

				atomic.AddInt64(&successfulOps, 1)
			}
		}
	}

	// Spawn 20 of these workers for 10 seconds
	for i := 0; i < 20; i++ {
		wg.Add(1)
		go doFuzzy(i)
	}

	// Wait for them all to finish; Wait also orders the counter reads below
	// after every worker's atomic updates.
	wg.Wait()

	core.Logger().Printf("[TRACE] total operations tried: %d, total successful: %d", totalOps, successfulOps)
	if totalOps != successfulOps {
		t.Fatalf("total/successful ops mismatch: %d/%d", totalOps, successfulOps)
	}
}
Example #3
0
// TestHTTP_Forwarding_ClientTLS tests TLS connection state forwarding. It
// checks that a client TLS certificate can be used to log in to the cert
// auth backend through each standby node.
func TestHTTP_Forwarding_ClientTLS(t *testing.T) {
	handler1 := http.NewServeMux()
	handler2 := http.NewServeMux()
	handler3 := http.NewServeMux()

	coreConfig := &vault.CoreConfig{
		CredentialBackends: map[string]logical.Factory{
			"cert": credCert.Factory,
		},
	}

	// Chicken-and-egg: Handler needs a core. So we create handlers first, then
	// add routes chained to a Handler-created handler.
	cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
	for _, core := range cores {
		defer core.CloseListeners()
	}
	handler1.Handle("/", Handler(cores[0].Core))
	handler2.Handle("/", Handler(cores[1].Core))
	handler3.Handle("/", Handler(cores[2].Core))

	// make it easy to get access to the active
	core := cores[0].Core
	vault.TestWaitActive(t, core)

	root := cores[0].Root

	transport := cleanhttp.DefaultTransport()
	transport.TLSClientConfig = cores[0].TLSConfig

	client := &http.Client{
		Transport: transport,
	}

	// rootPost POSTs body to path on the active node using the root token.
	// It always closes the response body, and fails the test on a transport
	// error or on an error status reported by Vault.
	rootPost := func(path string, body []byte) {
		req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d%s", cores[0].Listeners[0].Address.Port, path),
			bytes.NewBuffer(body))
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Set(AuthHeaderName, root)
		resp, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()
		if err := (&api.Response{Response: resp}).Error(); err != nil {
			t.Fatalf("POST %s: %v", path, err)
		}
	}

	// Mount the cert auth backend
	rootPost("/v1/sys/auth/cert", []byte("{\"type\": \"cert\"}"))

	// Trust the cluster CA for logins, granting the default policy
	type certConfig struct {
		Certificate string `json:"certificate"`
		Policies    string `json:"policies"`
	}
	encodedCertConfig, err := json.Marshal(&certConfig{
		Certificate: vault.TestClusterCACert,
		Policies:    "default",
	})
	if err != nil {
		t.Fatal(err)
	}
	rootPost("/v1/auth/cert/certs/test", encodedCertConfig)

	addrs := []string{
		fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
		fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
	}

	// Use a fresh transport so no lingering connection from the setup
	// requests can be reused, even though those went to a different address.
	// Redirects are refused so that a standby which redirects instead of
	// forwarding fails the login.
	loginTransport := cleanhttp.DefaultTransport()
	loginTransport.TLSClientConfig = cores[0].TLSConfig

	loginHTTPClient := &http.Client{
		Transport: loginTransport,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return fmt.Errorf("redirects not allowed in this test")
		},
	}

	for _, addr := range addrs {
		config := api.DefaultConfig()
		config.Address = addr
		config.HttpClient = loginHTTPClient
		apiClient, err := api.NewClient(config)
		if err != nil {
			t.Fatal(err)
		}

		secret, err := apiClient.Logical().Write("auth/cert/login", nil)
		if err != nil {
			t.Fatal(err)
		}
		if secret == nil {
			t.Fatal("secret is nil")
		}
		if secret.Auth == nil {
			t.Fatal("auth is nil")
		}
		if len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" {
			t.Fatalf("bad policies: %#v", secret.Auth.Policies)
		}
	}
}
Example #4
0
// TestHTTP_Wrapping exercises response wrapping against the active node:
//   - invalid input to sys/wrapping/lookup and sys/wrapping/unwrap;
//   - lookup of a wrapping token;
//   - the four ways of unwrapping (client token, token in body, reading
//     cubbyhole/response directly, and the api Unwrap helper);
//   - custom wrapping via sys/wrapping/wrap;
//   - rewrapping.
//
// It also checks that every wrapping token can be used only once.
func TestHTTP_Wrapping(t *testing.T) {
	handler1 := http.NewServeMux()
	handler2 := http.NewServeMux()
	handler3 := http.NewServeMux()

	coreConfig := &vault.CoreConfig{}

	// Chicken-and-egg: Handler needs a core. So we create handlers first, then
	// add routes chained to a Handler-created handler.
	cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
	for _, core := range cores {
		defer core.CloseListeners()
	}
	handler1.Handle("/", Handler(cores[0].Core))
	handler2.Handle("/", Handler(cores[1].Core))
	handler3.Handle("/", Handler(cores[2].Core))

	// make it easy to get access to the active
	core := cores[0].Core
	vault.TestWaitActive(t, core)

	root := cores[0].Root

	transport := cleanhttp.DefaultTransport()
	transport.TLSClientConfig = cores[0].TLSConfig
	httpClient := &http.Client{
		Transport: transport,
	}
	addr := fmt.Sprintf("https://127.0.0.1:%d", cores[0].Listeners[0].Address.Port)
	config := api.DefaultConfig()
	config.Address = addr
	config.HttpClient = httpClient
	client, err := api.NewClient(config)
	if err != nil {
		t.Fatal(err)
	}
	client.SetToken(root)

	// The payload stored at secret/foo and expected back from every unwrap.
	expected := map[string]interface{}{
		"zip": "zap",
	}

	// Write a value that we will use with wrapping for lookup
	_, err = client.Logical().Write("secret/foo", map[string]interface{}{
		"zip": "zap",
	})
	if err != nil {
		t.Fatal(err)
	}

	// wrapSecretFoo wraps reads of secret/foo for five minutes and defers
	// to the api default for everything else.
	wrapSecretFoo := func(operation, path string) string {
		if operation == "GET" && path == "secret/foo" {
			return "5m"
		}

		return api.DefaultWrappingLookupFunc(operation, path)
	}
	client.SetWrappingLookupFunc(wrapSecretFoo)

	// expectError fails the test unless err is non-nil. desc says which
	// request was expected to fail and why.
	expectError := func(desc string, err error) {
		if err == nil {
			t.Fatalf("expected error: %s", desc)
		}
	}

	// readWrapped reads secret/foo as root, which wrapSecretFoo causes to be
	// wrapped, and fails the test unless wrap info came back.
	readWrapped := func() *api.Secret {
		client.SetToken(root)
		secret, err := client.Logical().Read("secret/foo")
		if err != nil {
			t.Fatal(err)
		}
		if secret == nil || secret.WrapInfo == nil {
			t.Fatal("secret or wrap info is nil")
		}
		return secret
	}

	// checkData fails the test unless secret carries exactly the
	// secret/foo payload.
	checkData := func(desc string, secret *api.Secret) {
		if secret == nil {
			t.Fatalf("%s: secret is nil", desc)
		}
		if !reflect.DeepEqual(secret.Data, expected) {
			t.Fatalf("%s data did not match expected: %#v", desc, secret.Data)
		}
	}

	// First test: basic things that should fail, lookup edition
	_, err = client.Logical().Write("sys/wrapping/lookup", nil)
	expectError("lookup: root token isn't a wrapping token", err)
	_, err = client.Logical().Write("sys/wrapping/lookup", map[string]interface{}{
		"foo": "bar",
	})
	expectError("lookup: token not supplied", err)
	_, err = client.Logical().Write("sys/wrapping/lookup", map[string]interface{}{
		"token": "bar",
	})
	expectError("lookup: nonexistent token isn't a wrapping token", err)

	// Second: basic things that should fail, unwrap edition
	_, err = client.Logical().Unwrap(root)
	expectError("Unwrap: root token isn't a wrapping token", err)
	_, err = client.Logical().Write("sys/wrapping/unwrap", nil)
	expectError("unwrap: root token isn't a wrapping token", err)
	_, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{
		"foo": "bar",
	})
	expectError("unwrap: token not supplied", err)
	_, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{
		"token": "bar",
	})
	expectError("unwrap: nonexistent token isn't a wrapping token", err)

	//
	// Test lookup
	//

	wrapInfo := readWrapped().WrapInfo

	// Test this twice to ensure no ill effect to the wrapping token as a result of the lookup
	for i := 0; i < 2; i++ {
		secret, err := client.Logical().Write("sys/wrapping/lookup", map[string]interface{}{
			"token": wrapInfo.Token,
		})
		if err != nil {
			t.Fatal(err)
		}
		if secret == nil || secret.Data == nil {
			t.Fatal("secret or secret data is nil")
		}
		// Assumes the api client decodes numbers as json.Number; the
		// assertion panics otherwise.
		creationTTL, err := secret.Data["creation_ttl"].(json.Number).Int64()
		if err != nil {
			t.Fatal(err)
		}
		if int(creationTTL) != wrapInfo.TTL {
			t.Fatalf("mismatched ttls: %d vs %d", creationTTL, wrapInfo.TTL)
		}
		if secret.Data["creation_time"].(string) != wrapInfo.CreationTime.Format(time.RFC3339Nano) {
			t.Fatalf("mismatched creation times: %q vs %q", secret.Data["creation_time"], wrapInfo.CreationTime.Format(time.RFC3339Nano))
		}
	}

	//
	// Test unwrap
	//

	// Unwrap using the wrapping token as the client token
	wrapInfo = readWrapped().WrapInfo
	client.SetToken(wrapInfo.Token)
	ret1, err := client.Logical().Write("sys/wrapping/unwrap", nil)
	if err != nil {
		t.Fatal(err)
	}
	checkData("ret1", ret1)
	// Should be expired and fail
	_, err = client.Logical().Write("sys/wrapping/unwrap", nil)
	expectError("unwrap via client token: token already used", err)

	// Unwrap as root, passing the wrapping token in the request body
	wrapInfo = readWrapped().WrapInfo
	ret2, err := client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{
		"token": wrapInfo.Token,
	})
	if err != nil {
		t.Fatal(err)
	}
	checkData("ret2", ret2)
	// Should be expired and fail
	_, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{
		"token": wrapInfo.Token,
	})
	expectError("unwrap via body token: token already used", err)

	// Read the wrapped response directly from the token's cubbyhole
	wrapInfo = readWrapped().WrapInfo
	client.SetToken(wrapInfo.Token)
	ret3, err := client.Logical().Read("cubbyhole/response")
	if err != nil {
		t.Fatal(err)
	}
	if ret3 == nil || ret3.Data == nil {
		t.Fatal("cubbyhole response secret or secret data is nil")
	}
	// Should be expired and fail
	_, err = client.Logical().Write("cubbyhole/response", nil)
	expectError("cubbyhole/response: token already used", err)
	responseJSON, ok := ret3.Data["response"].(string)
	if !ok {
		t.Fatalf("cubbyhole response is not a string: %#v", ret3.Data["response"])
	}
	var ret3Secret api.Secret
	if err := jsonutil.DecodeJSON([]byte(responseJSON), &ret3Secret); err != nil {
		t.Fatal(err)
	}
	checkData("ret3", &ret3Secret)

	// Read via Unwrap method
	wrapInfo = readWrapped().WrapInfo
	ret4, err := client.Logical().Unwrap(wrapInfo.Token)
	if err != nil {
		t.Fatal(err)
	}
	checkData("ret4", ret4)
	// Should be expired and fail
	_, err = client.Logical().Unwrap(wrapInfo.Token)
	expectError("Unwrap: token already used", err)

	//
	// Custom wrapping
	//

	client.SetToken(root)
	data := map[string]interface{}{
		"zip":   "zap",
		"three": json.Number("2"),
	}

	// Don't set a request TTL on that path, should fail
	client.SetWrappingLookupFunc(func(operation, path string) string {
		return ""
	})
	_, err = client.Logical().Write("sys/wrapping/wrap", data)
	expectError("wrap: no wrapping TTL requested", err)

	// Re-set the lookup function
	client.SetWrappingLookupFunc(wrapSecretFoo)
	wrapped, err := client.Logical().Write("sys/wrapping/wrap", data)
	if err != nil {
		t.Fatal(err)
	}
	if wrapped == nil || wrapped.WrapInfo == nil {
		t.Fatal("wrapped secret or wrap info is nil")
	}
	unwrapped, err := client.Logical().Unwrap(wrapped.WrapInfo.Token)
	if err != nil {
		t.Fatal(err)
	}
	if unwrapped == nil {
		t.Fatal("unwrapped secret is nil")
	}
	if !reflect.DeepEqual(data, unwrapped.Data) {
		t.Fatalf("custom wrap did not match expected: %#v", unwrapped.Data)
	}

	//
	// Test rewrap
	//

	wrapInfo = readWrapped().WrapInfo
	rewrapped, err := client.Logical().Write("sys/wrapping/rewrap", map[string]interface{}{
		"token": wrapInfo.Token,
	})
	if err != nil {
		t.Fatal(err)
	}
	if rewrapped == nil || rewrapped.WrapInfo == nil {
		t.Fatal("rewrapped secret or wrap info is nil")
	}
	// Rewrapping consumes the original token, so unwrapping it must fail
	_, err = client.Logical().Write("sys/wrapping/unwrap", map[string]interface{}{
		"token": wrapInfo.Token,
	})
	expectError("unwrap after rewrap: original token already used", err)

	// Attempt unwrapping the rewrapped token
	wrapToken := rewrapped.WrapInfo.Token
	secret, err := client.Logical().Unwrap(wrapToken)
	if err != nil {
		t.Fatal(err)
	}
	// Should be expired and fail
	_, err = client.Logical().Unwrap(wrapToken)
	expectError("Unwrap: rewrapped token already used", err)

	checkData("rewrapped secret", secret)
}