예제 #1
0
// TestDelete drives DELETE /secrets/delete/secrets/test-secret through the
// router with the package-level database replaced by a mock.
//
// The mock resolves the root secret by name (filling in ID 2) and expects
// DeleteSecret to be called with that populated secret.
// NOTE(review): the test does not assert anything about the HTTP response;
// consider checking w.Code once the expected status is pinned down.
func TestDelete(t *testing.T) {
	w := httptest.NewRecorder()

	m := mux.NewRouter()
	addRoutes(m)

	// Check key creation before Display() reads the key.
	key := new(secrets.Key)
	err := key.New("968cd432-c97a-11e5-9956-625662870761")
	assert.Nil(t, err, "Should not return error")
	priv := key.Display()

	testDb := new(mocks.DB)

	r, err := http.NewRequest("DELETE", "/secrets/delete/secrets/test-secret", nil)
	assert.Nil(t, err, "Should not return error")

	// Previously both results were discarded; surface the error so a broken
	// secrets.New fails the test instead of being silently ignored.
	_, err = secrets.New("test-secret", []byte("test"))
	assert.Nil(t, err, "Should not return error")

	authSetup(testDb, r, priv)

	testDb.On("GetRootSecret", &secrets.Secret{Name: "test-secret"}).Run(func(args mock.Arguments) {
		args.Get(0).(*secrets.Secret).Name = "test-secret"
		args.Get(0).(*secrets.Secret).ID = 2
	}).Return(nil)
	testDb.On("DeleteSecret", &secrets.Secret{ID: 2, Name: "test-secret"}).Return(nil)

	database = testDb

	m.ServeHTTP(w, r)
}
예제 #2
0
// auth authenticates the current request against a key stored in the database.
//
// Credentials come from the X-Secret-ID / X-Secret-Key headers, falling back
// to the "secretid" / "secretkey" form values. The key ID and the
// base64-decoded private key are recorded in a.keyID and a.key for later use
// by handlers. Authentication succeeds when the curve25519 public key derived
// from the supplied private key matches the stored public key (compared in
// constant time). a.admin is set only after that check passes, and only for
// keys that are not read-only.
//
// The master key name bypasses the key-ID format check; every other ID must
// match secretIDRegex. The secret key must always match secretKeyRegex.
func (a *api) auth() bool {
	k := new(secrets.Key)

	// Grab the credentials, look in the header first and fall back to the query string.
	if k.Name = a.req.Header.Get("X-Secret-ID"); k.Name == "" {
		k.Name = a.req.FormValue("secretid")
	}
	secretKey := a.req.Header.Get("X-Secret-Key")
	if secretKey == "" {
		secretKey = a.req.FormValue("secretkey")
	}

	// If the master key has been used then just check the key, else check both.
	if !secretKeyRegex.MatchString(secretKey) ||
		(k.Name != secrets.MasterKeyName && !secretIDRegex.MatchString(k.Name)) {
		log.Error("Invalid auth credential format.")
		return false
	}

	a.keyID = k.Name
	var err error
	a.key, err = base64.StdEncoding.DecodeString(secretKey)
	if err != nil {
		return false
	}

	priv := new([32]byte)
	copy(priv[:], a.key)
	defer secrets.Zero(priv[:])

	if err = database.GetKey(k); err != nil {
		return false
	}

	pub := new([32]byte)
	curve25519.ScalarBaseMult(pub, priv)
	if subtle.ConstantTimeCompare(pub[:], k.Public) != 1 {
		return false
	}

	// Grant admin rights only once the caller has proven possession of the
	// key; setting this before the comparison left a.admin true on failure.
	if !k.ReadOnly {
		a.admin = true
	}
	return true
}
예제 #3
0
// TestListSecret checks that GET /secrets/list/secrets writes every secret
// handed out by the mocked ListSecrets iterator.
//
// The mock serves secretList in windows of whatever size the handler asks
// for. The expected body is two indented JSON arrays concatenated
// (secretList[0:10] then secretList[10:]), so this assumes the handler pages
// in chunks of 10 and encodes each page separately.
func TestListSecret(t *testing.T) {
	w := httptest.NewRecorder()

	m := mux.NewRouter()
	addRoutes(m)

	secretList := make([]secrets.Secret, 20)

	for i := range secretList {
		s, err := secrets.New(fmt.Sprintf("secret-%d", i), []byte("testmessage"))
		// Stop before dereferencing s: it is nil when err is non-nil.
		if !assert.Nil(t, err, "Should not return error") {
			t.FailNow()
		}
		secretList[i] = *s
	}

	testDb := new(mocks.DB)

	r, err := http.NewRequest("GET", "/secrets/list/secrets", nil)
	assert.Nil(t, err, "Should not return error")

	// Check key creation before Display() reads the key.
	key := new(secrets.Key)
	err = key.New("968cd432-c97a-11e5-9956-625662870761")
	assert.Nil(t, err, "Should not return error")
	priv := key.Display()

	authSetup(testDb, r, priv)

	pos := 0

	testDb.On("ListSecrets", mock.Anything).Return(func(n int) ([]secrets.Secret, error) {
		// Clamp the window to the slice so trailing calls return a short or
		// empty page, then advance the cursor.
		start := pos
		end := pos + n
		if start >= len(secretList) {
			start = len(secretList)
		}
		if end >= len(secretList) {
			end = len(secretList)
		}
		pos = end
		return secretList[start:end], nil
	})

	database = testDb

	m.ServeHTTP(w, r)

	expected, err := json.MarshalIndent(secretList[0:10], "", "  ")
	assert.Nil(t, err, "Should not return error")
	buf, err := json.MarshalIndent(secretList[10:], "", "  ")
	assert.Nil(t, err, "Should not return error")
	expected = append(expected, buf...)

	assert.Equal(t, expected, w.Body.Bytes())
}
예제 #4
0
// TestView checks that View decrypts a secret shared with the caller's key
// and writes the plaintext back.
//
// A root secret is created and shared with a freshly generated key. The mock
// database returns the shared copy for (testsecret, key) and the root copy
// for any root lookup, and the response body must be the original message.
func TestView(t *testing.T) {
	w := httptest.NewRecorder()

	root, err := secrets.New("testsecret", []byte("testmessage"))
	assert.Nil(t, err, "Should not return error")

	// Check key creation before Display() reads the key.
	key := new(secrets.Key)
	err = key.New("968cd432-c97a-11e5-9956-625662870761")
	assert.Nil(t, err, "Should not return error")
	priv := key.Display()

	shared, err := root.Share(key)
	assert.Nil(t, err, "Should not return error")

	req := request{Name: "testsecret"}
	data, err := json.Marshal(req)
	assert.Nil(t, err, "Should not return error")

	r, err := http.NewRequest("POST", "/secrets/view", bytes.NewReader(data))
	assert.Nil(t, err, "Should not return error")

	testDb := new(mocks.DB)

	authSetup(testDb, r, priv)

	testDb.On(
		"GetSharedSecret",
		&secrets.Secret{Name: "testsecret"},
		&secrets.Key{Name: "968cd432-c97a-11e5-9956-625662870761"}).Run(
		func(args mock.Arguments) {
			args.Get(0).(*secrets.Secret).Name = shared.Name
			args.Get(0).(*secrets.Secret).Nonce = shared.Nonce
			args.Get(0).(*secrets.Secret).Message = shared.Message
			args.Get(0).(*secrets.Secret).Pubkey = shared.Pubkey
			args.Get(0).(*secrets.Secret).Key = shared.Key
		}).Return(nil)

	testDb.On("GetRootSecret", mock.AnythingOfType("*secrets.Secret")).Run(func(args mock.Arguments) {
		args.Get(0).(*secrets.Secret).Name = root.Name
		args.Get(0).(*secrets.Secret).Nonce = root.Nonce
		args.Get(0).(*secrets.Secret).Message = root.Message
		args.Get(0).(*secrets.Secret).Pubkey = root.Pubkey
		args.Get(0).(*secrets.Secret).Key = root.Key
	}).Return(nil)

	database = testDb

	View(w, r)
	assert.Equal(t, http.StatusOK, w.Code)
	assert.Equal(t, "testmessage", w.Body.String())
}
예제 #5
0
func readDBcert() (cert []byte, err error) {
	root := new(secrets.Secret)
	shared := new(secrets.Secret)
	root.Name = certName
	shared.Name = certName

	key := new(secrets.Key)
	key.Name = certID

	priv, err := base64.StdEncoding.DecodeString(certKey)
	if err != nil {
		return
	}

	err = database.GetSharedSecret(shared, key)
	switch err {

	case gorm.ErrRecordNotFound:
		err = errors.New("Cert is not shared or does not exist")
		return

	case nil:
		break

	default:
		return

	}

	err = database.GetRootSecret(root)
	switch err {

	case gorm.ErrRecordNotFound:
		err = errors.New("Cert does not exist")
		return

	case nil:
		break

	default:
		return
	}

	return root.Decrypt(shared, priv)
}
예제 #6
0
// Key adds a new secret key to the vault
// Key adds a new secret key to the vault.
//
// Only authenticated admin keys may call it. The request's Name becomes the
// new key's ID (a fresh UUID when empty) and must match secretIDRegex. The new
// key is read-only unless the request sets Admin. On success the key's ID,
// the value from key.Display() and the read-only flag are returned with
// status 201.
func Key(w http.ResponseWriter, r *http.Request) {
	api := newAPI(w, r)
	defer api.req.Body.Close()

	if !api.auth() || !api.admin {
		api.error("Unauthorized", 401)
		return
	}

	request, err := api.read()
	if err != nil {
		log.Debug(err)
		api.error("Bad request", 400)
		return
	}

	if request.Name == "" {
		request.Name = uuid.New()
	}

	if !secretIDRegex.MatchString(request.Name) {
		api.error("Invalid key ID", 400)
		// Stop here: without this return the handler went on to create and
		// store the key after already writing the 400 response.
		return
	}

	key := new(secrets.Key)
	if err = key.New(request.Name); err != nil {
		log.Error(err)
		api.error("Server error", 500)
		return
	}

	// Non-admin keys are created read-only.
	key.ReadOnly = !request.Admin

	if err = database.AddKey(key); err != nil {
		log.Error(err)
		api.error("Database error", 500)
		return
	}

	log.Info("New key added: ", key.Name)

	api.reply(secrets.Key{
		Name:     key.Name,
		Key:      key.Display(),
		ReadOnly: key.ReadOnly,
	},
		201)
}
예제 #7
0
// ListKeys returns an iterator function that walks through all keys in the database.
// The iterator takes an integer argument, which is the maximum number of results to return per iteration.
// If a secret name is specified, the results are limited to keys with access to that secret.
// ListKeys returns an iterator function that walks through all keys in the database.
// The iterator takes an integer argument, which is the maximum number of results to return per iteration.
// If a secret name is specified, the results are limited to keys with access to that secret.
//
// Keys are returned in ascending id order. On error the iterator returns a nil
// slice and does not advance, so a retry re-reads the same page. A NULL
// read_only column is reported as ReadOnly == false.
func (p *DB) ListKeys(secret *string) func(int) ([]secrets.Key, error) {
	// pos is the offset of the next page; it only advances after a page has
	// been read successfully.
	pos := 0

	return func(n int) ([]secrets.Key, error) {
		if err := p.refresh(); err != nil {
			return nil, err
		}

		var (
			rows *sql.Rows
			err  error
		)
		if secret != nil {
			// Qualify the ORDER BY column: both joined tables have an id column.
			rows, err = p.conn.Table("keys").Select(
				"keys.id, keys.name, keys.key, keys.nonce, keys.public, keys.read_only").Joins(
				"left join secrets on keys.id = secrets.key_id").Where(
				"secrets.name = ?", *secret).Order("keys.id asc").Limit(n).Offset(pos).Rows()
		} else {
			rows, err = p.conn.Table("keys").Select("id, name, key, nonce, public, read_only").Order("id asc").Limit(n).Offset(pos).Rows()
		}
		// rows is unusable when the query failed; bail out before touching it.
		if err != nil {
			return nil, err
		}
		// Guarantees the rows are released on every early return below.
		defer rows.Close()

		var res []secrets.Key
		for rows.Next() {
			var (
				out secrets.Key
				ro  sql.NullBool
			)
			if err := rows.Scan(&out.ID, &out.Name, &out.Key, &out.Nonce, &out.Public, &ro); err != nil {
				return nil, err
			}
			out.ReadOnly = ro.Valid && ro.Bool
			res = append(res, out)
		}
		// Surface any error that ended iteration early, then any close error.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		if err := rows.Close(); err != nil {
			return nil, err
		}
		pos += len(res)
		return res, nil
	}
}
예제 #8
0
// View downloads a decrypted message
// View downloads a decrypted message.
//
// Any authenticated key may call it. The secret name comes from the request
// body, overridden by the "messageName" route parameter when present. The
// shared copy for the caller's key and the root copy are loaded, and the root
// is decrypted with the caller's private key. The plaintext is written raw
// with status 200 and wiped from memory afterwards. A missing shared or root
// copy yields 404, other database errors yield 500, and a decryption failure
// yields 500.
func View(w http.ResponseWriter, r *http.Request) {
	api := newAPI(w, r)
	defer api.req.Body.Close()

	if ok := api.auth(); !ok {
		api.error("Unauthorized", 401)
		return
	}

	request, err := api.read()
	if err != nil {
		log.Debug(err)
		api.error("Bad request", 400)
		return
	}

	if override, found := api.params["messageName"]; found {
		request.Name = override
	}

	key := &secrets.Key{Name: api.keyID}
	shared := &secrets.Secret{Name: request.Name}
	root := &secrets.Secret{Name: request.Name}

	// lookupFailed writes the HTTP error for a failed database lookup and
	// reports whether the handler must stop.
	lookupFailed := func(lookupErr error) bool {
		if lookupErr == nil {
			return false
		}
		if lookupErr == gorm.ErrRecordNotFound {
			api.error("Secret does not exist", 404)
		} else {
			log.Error(lookupErr)
			api.error("Database error", 500)
		}
		return true
	}

	if lookupFailed(database.GetSharedSecret(shared, key)) {
		return
	}
	if lookupFailed(database.GetRootSecret(root)) {
		return
	}

	plaintext, err := root.Decrypt(shared, api.key)
	if err != nil {
		log.Debug(err)
		api.error("Cannot decrypt secret", 500)
		return
	}
	defer secrets.Zero(plaintext)

	log.Info("Secret: ", shared.Name, " viewed by: ", key.Name)
	viewCount++

	api.rawMessage(plaintext, 200)
}
예제 #9
0
// Share grants a key access to a message
// Share grants a key access to a message.
//
// Only authenticated admin keys may share. The request must name both the
// target key (KeyID) and the secret (Name). The root secret is shared with
// the target key via secret.Share, the resulting copy is stored, and the
// handler responds 201 "OK". A missing key or secret yields 404; other
// database failures yield 500.
func Share(w http.ResponseWriter, r *http.Request) {
	api := newAPI(w, r)
	defer api.req.Body.Close()

	if !api.auth() || !api.admin {
		api.error("Unauthorized", 401)
		return
	}

	request, err := api.read()
	if err != nil {
		log.Debug(err)
		api.error("Bad request", 400)
		return
	}

	if len(request.KeyID) == 0 || len(request.Name) == 0 {
		api.error("Missing elements in request", 400)
		return
	}

	key := new(secrets.Key)
	key.Name = request.KeyID
	key.Key = request.Key

	switch err = database.GetKey(key); err {
	case nil:
	case gorm.ErrRecordNotFound:
		// An unknown target key is a client error, not a server failure.
		api.error("Key does not exist", 404)
		return
	default:
		log.Error(err)
		api.error("Database error", 500)
		return
	}

	secret := new(secrets.Secret)
	secret.Name = request.Name

	switch err = database.GetRootSecret(secret); err {
	case nil:
	case gorm.ErrRecordNotFound:
		api.error("Secret does not exist", 404)
		return
	default:
		log.Error(err)
		api.error("Database error", 500)
		return
	}

	shared, err := secret.Share(key)
	if err != nil {
		log.Error(err)
		api.error(err.Error(), 500)
		return
	}

	if err = database.AddSecret(shared); err != nil {
		log.Error(err)
		api.error("Database error", 500)
		return
	}

	log.Info("Secret: ", shared.Name, " shared with: ", key.Name)

	api.message("OK", 201)
}