예제 #1
0
파일: path_roles.go 프로젝트: mhurne/vault
// pathRoleCreate validates a role definition and writes it to storage.
//
// The "name", "sql", "username_length", "rolename_length" and
// "displayname_length" fields are read from data. Each non-empty,
// semicolon-separated statement in the SQL is test-prepared against the
// handle returned by b.DB using placeholder name/password values; the role is
// then stored as JSON under "role/<name>".
//
// A statement that fails to prepare produces a logical error response with a
// nil error. Failures obtaining the database handle, encoding the entry, or
// writing it to storage are returned as errors. On success both return values
// are nil.
func (b *backend) pathRoleCreate(
	req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	roleName := data.Get("name").(string)
	roleSQL := data.Get("sql").(string)
	usernameLen := data.Get("username_length").(int)
	rolenameLen := data.Get("rolename_length").(int)
	displaynameLen := data.Get("displayname_length").(int)

	conn, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	// Preparing each statement is a dry run: it surfaces SQL the database
	// rejects without executing anything.
	for _, statement := range strutil.ParseArbitraryStringSlice(roleSQL, ";") {
		statement = strings.TrimSpace(statement)
		if statement == "" {
			continue
		}

		placeholders := map[string]string{
			"name":     "foo",
			"password": "******",
		}
		prepared, prepErr := conn.Prepare(Query(statement, placeholders))
		if prepErr != nil {
			msg := fmt.Sprintf("Error testing query: %s", prepErr)
			return logical.ErrorResponse(msg), nil
		}
		prepared.Close()
	}

	// Persist the role definition.
	role := &roleEntry{
		SQL:               roleSQL,
		UsernameLength:    usernameLen,
		DisplaynameLength: displaynameLen,
		RolenameLength:    rolenameLen,
	}
	entry, err := logical.StorageEntryJSON("role/"+roleName, role)
	if err != nil {
		return nil, err
	}
	if err = req.Storage.Put(entry); err != nil {
		return nil, err
	}
	return nil, nil
}
예제 #2
0
// pathRoleCreateRead issues a new set of credentials for the role named by
// the "name" request field.
//
// It loads the role and the lease configuration, builds a username from the
// request's display name (truncated to 26 bytes) plus a UUID, generates a UUID
// password and an expiration timestamp, and then runs every non-empty,
// semicolon-separated statement of the role's SQL inside one transaction with
// "name", "password" and "expiration" substituted in.
//
// On success the returned secret response carries the username and password,
// keeps the username as internal data, and has its TTL set to the lease
// duration. An unknown role produces a logical error response with a nil
// error. Storage, UUID and database failures are returned as errors.
func (b *backend) pathRoleCreateRead(
	req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: enter")
	defer b.logger.Println("[TRACE] postgres/pathRoleCreateRead: exit")
	name := data.Get("name").(string)

	// Get the role
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: getting role")
	role, err := b.Role(req.Storage, name)
	if err != nil {
		return nil, err
	}
	if role == nil {
		return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
	}

	// Determine if we have a lease
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: getting lease")
	lease, err := b.Lease(req.Storage)
	if err != nil {
		return nil, err
	}
	// Unlike some other backends we need a lease here (can't leave as 0 and
	// let core fill it in) because Postgres also expires users as a safety
	// measure, so cannot be zero
	if lease == nil {
		lease = &configLease{
			Lease: b.System().DefaultLeaseTTL(),
		}
	}

	// Generate the username, password and expiration. PG limits user to 63 characters
	displayName := req.DisplayName
	if len(displayName) > 26 {
		displayName = displayName[:26]
	}
	userUUID, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	username := fmt.Sprintf("%s-%s", displayName, userUUID)
	if len(username) > 63 {
		username = username[:63]
	}
	password, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	expiration := time.Now().
		Add(lease.Lease).
		Format("2006-01-02 15:04:05-0700")

	// Get our handle
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: getting database handle")
	db, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	// Start a transaction
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: starting transaction")
	tx, err := db.Begin()
	if err != nil {
		return nil, err
	}
	// The rollback is deferred unconditionally so every early return below
	// abandons the transaction; its result is ignored, including after a
	// successful commit.
	defer func() {
		b.logger.Println("[TRACE] postgres/pathRoleCreateRead: rolling back transaction")
		tx.Rollback()
	}()

	// Execute each query
	for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
		query = strings.TrimSpace(query)
		if len(query) == 0 {
			continue
		}

		b.logger.Println("[TRACE] postgres/pathRoleCreateRead: preparing statement")
		stmt, err := tx.Prepare(Query(query, map[string]string{
			"name":       username,
			"password":   password,
			"expiration": expiration,
		}))
		if err != nil {
			return nil, err
		}
		b.logger.Println("[TRACE] postgres/pathRoleCreateRead: executing statement")
		_, err = stmt.Exec()
		// Close the statement as soon as it has run. A defer here would keep
		// every prepared statement open until the function returns.
		stmt.Close()
		if err != nil {
			return nil, err
		}
	}

	// Commit the transaction
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: committing transaction")
	if err := tx.Commit(); err != nil {
		return nil, err
	}

	// Return the secret
	b.logger.Println("[TRACE] postgres/pathRoleCreateRead: generating secret")
	resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
		"username": username,
		"password": password,
	}, map[string]interface{}{
		"username": username,
	})
	resp.Secret.TTL = lease.Lease
	return resp, nil
}
예제 #3
0
// secretCredsRevoke revokes the database user recorded in the secret's
// internal data.
//
// The username is read from req.Secret.InternalData["username"] and must be a
// string. If the internal data also names a role and that role defines
// revocation SQL, each non-empty, semicolon-separated statement of that SQL is
// executed inside a single transaction with "name" substituted. Otherwise the
// default logic runs: it checks that the role exists in pg_roles, revokes
// privileges on every schema listed for the grantee plus schema public and the
// current database, and finally drops the role.
//
// The returned response is nil unless the named role could not be found, in
// which case it carries a warning. Missing or non-string internal data, and
// any storage or database failure, are returned as errors.
func (b *backend) secretCredsRevoke(
	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	// Get the username from the internal data
	usernameRaw, ok := req.Secret.InternalData["username"]
	if !ok {
		return nil, fmt.Errorf("secret is missing username internal data")
	}
	// Without this check a non-string value would leave username empty and
	// the revocation would report success without touching the real user.
	username, ok := usernameRaw.(string)
	if !ok {
		return nil, fmt.Errorf("secret username internal data is not a string")
	}

	var revocationSQL string
	var resp *logical.Response

	if roleNameRaw, ok := req.Secret.InternalData["role"]; ok {
		// Checked assertion: a malformed value is reported instead of panicking.
		roleName, ok := roleNameRaw.(string)
		if !ok {
			return nil, fmt.Errorf("secret role internal data is not a string")
		}
		role, err := b.Role(req.Storage, roleName)
		if err != nil {
			return nil, err
		}
		if role == nil {
			resp = &logical.Response{}
			resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleName))
		} else {
			revocationSQL = role.RevocationSQL
		}
	}

	// Get our connection
	db, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	switch revocationSQL {

	// This is the default revocation logic. If revocation SQL is provided it
	// is simply executed as-is.
	case "":
		// Check if the role exists
		var exists bool
		err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
		if err != nil && err != sql.ErrNoRows {
			return nil, err
		}

		if !exists {
			return resp, nil
		}

		// Query for permissions; we need to revoke permissions before we can drop
		// the role
		// This isn't done in a transaction because even if we fail along the way,
		// we want to remove as much access as possible
		stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
		if err != nil {
			return nil, err
		}
		defer stmt.Close()

		rows, err := stmt.Query(username)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		const initialNumRevocations = 16
		revocationStmts := make([]string, 0, initialNumRevocations)
		for rows.Next() {
			var schema string
			err = rows.Scan(&schema)
			if err != nil {
				// keep going; remove as many permissions as possible right now
				continue
			}
			revocationStmts = append(revocationStmts, fmt.Sprintf(
				`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
				pq.QuoteIdentifier(schema),
				pq.QuoteIdentifier(username)))

			revocationStmts = append(revocationStmts, fmt.Sprintf(
				`REVOKE USAGE ON SCHEMA %s FROM %s;`,
				pq.QuoteIdentifier(schema),
				pq.QuoteIdentifier(username)))
		}

		// for good measure, revoke all privileges and usage on schema public
		revocationStmts = append(revocationStmts, fmt.Sprintf(
			`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
			pq.QuoteIdentifier(username)))

		revocationStmts = append(revocationStmts, fmt.Sprintf(
			"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
			pq.QuoteIdentifier(username)))

		revocationStmts = append(revocationStmts, fmt.Sprintf(
			"REVOKE USAGE ON SCHEMA public FROM %s;",
			pq.QuoteIdentifier(username)))

		// get the current database name so we can issue a REVOKE CONNECT for
		// this username
		var dbname sql.NullString
		if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
			return nil, err
		}

		if dbname.Valid {
			revocationStmts = append(revocationStmts, fmt.Sprintf(
				`REVOKE CONNECT ON DATABASE %s FROM %s;`,
				pq.QuoteIdentifier(dbname.String),
				pq.QuoteIdentifier(username)))
		}

		// again, here, we do not stop on error, as we want to remove as
		// many permissions as possible right now
		var lastStmtError error
		for _, query := range revocationStmts {
			stmt, err := db.Prepare(query)
			if err != nil {
				lastStmtError = err
				continue
			}
			if _, err := stmt.Exec(); err != nil {
				lastStmtError = err
			}
			// Close per iteration rather than deferring, so statements do not
			// pile up until the function returns.
			stmt.Close()
		}

		// can't drop if not all privileges are revoked
		if rows.Err() != nil {
			return nil, fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
		}
		if lastStmtError != nil {
			return nil, fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
		}

		// Drop this user
		stmt, err = db.Prepare(fmt.Sprintf(
			`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
		if err != nil {
			return nil, err
		}
		defer stmt.Close()
		if _, err := stmt.Exec(); err != nil {
			return nil, err
		}

	// We have revocation SQL, execute directly, within a transaction
	default:
		tx, err := db.Begin()
		if err != nil {
			return nil, err
		}
		// Deferred unconditionally so any early return abandons the
		// transaction; the result is ignored.
		defer func() {
			tx.Rollback()
		}()

		for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
			query = strings.TrimSpace(query)
			if len(query) == 0 {
				continue
			}

			stmt, err := tx.Prepare(Query(query, map[string]string{
				"name": username,
			}))
			if err != nil {
				return nil, err
			}
			_, err = stmt.Exec()
			// Close right after execution instead of deferring inside the loop.
			stmt.Close()
			if err != nil {
				return nil, err
			}
		}

		if err := tx.Commit(); err != nil {
			return nil, err
		}
	}

	return resp, nil
}
예제 #4
0
// pathCredsCreateRead issues a new set of credentials for the role named by
// the "name" request field.
//
// It loads the role and lease configuration, builds a username from the
// request's display name (truncated to 10 bytes) plus a UUID, generates a UUID
// password, and runs the role's SQL, prefixed with a USE statement for
// b.defaultDb, one semicolon-separated statement at a time inside a single
// transaction with "name" and "password" substituted in.
//
// On success the returned secret response carries the username and password,
// keeps the username as internal data, and has its TTL set from the lease
// configuration (zero when none is stored). An unknown role produces a logical
// error response with a nil error. Storage, UUID and database failures are
// returned as errors.
func (b *backend) pathCredsCreateRead(
	req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	name := data.Get("name").(string)

	// Get the role
	role, err := b.Role(req.Storage, name)
	if err != nil {
		return nil, err
	}
	if role == nil {
		return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
	}

	// Determine if we have a lease configuration
	leaseConfig, err := b.LeaseConfig(req.Storage)
	if err != nil {
		return nil, err
	}
	if leaseConfig == nil {
		leaseConfig = &configLease{}
	}

	// Generate our username and password
	displayName := req.DisplayName
	if len(displayName) > 10 {
		displayName = displayName[:10]
	}
	userUUID, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	username := fmt.Sprintf("%s-%s", displayName, userUUID)
	password, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}

	// Get our handle
	db, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	// Start a transaction
	tx, err := db.Begin()
	if err != nil {
		return nil, err
	}
	// Deferred unconditionally so any early return abandons the transaction;
	// the result is ignored, including after a successful commit.
	defer tx.Rollback()

	// Always reset database to default db of connection.  Since it is in a
	// transaction, all statements will be on the same connection in the pool.
	roleSQL := fmt.Sprintf("USE [%s]; %s", b.defaultDb, role.SQL)

	// Execute each query
	for _, query := range strutil.ParseArbitraryStringSlice(roleSQL, ";") {
		query = strings.TrimSpace(query)
		if len(query) == 0 {
			continue
		}

		stmt, err := tx.Prepare(Query(query, map[string]string{
			"name":     username,
			"password": password,
		}))
		if err != nil {
			return nil, err
		}
		_, err = stmt.Exec()
		// Close the statement as soon as it has run. A defer here would keep
		// every prepared statement open until the function returns.
		stmt.Close()
		if err != nil {
			return nil, err
		}
	}

	// Commit the transaction
	if err := tx.Commit(); err != nil {
		return nil, err
	}

	// Return the secret
	resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
		"username": username,
		"password": password,
	}, map[string]interface{}{
		"username": username,
	})
	resp.Secret.TTL = leaseConfig.TTL
	return resp, nil
}
예제 #5
0
// pathRoleCreateRead issues a new set of credentials for the role named by
// the "name" request field.
//
// It loads the role and lease, builds a length-limited username from the role
// name, the request's display name and a UUID, generates a UUID password, and
// runs every non-empty, semicolon-separated statement of the role's SQL inside
// a single transaction with "name" and "password" substituted in.
//
// On success the returned secret response carries the username and password,
// keeps the username as internal data, and has its TTL set to the lease
// duration (zero when no lease is stored). An unknown role produces a logical
// error response with a nil error. Storage, UUID and database failures are
// returned as errors.
func (b *backend) pathRoleCreateRead(
	req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	name := data.Get("name").(string)

	// Get the role
	role, err := b.Role(req.Storage, name)
	if err != nil {
		return nil, err
	}
	if role == nil {
		return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
	}

	// Determine if we have a lease
	lease, err := b.Lease(req.Storage)
	if err != nil {
		return nil, err
	}
	if lease == nil {
		lease = &configLease{}
	}

	// Generate our username and password. The username will be a
	// concatenation of:
	//
	// - the role name, truncated to role.RolenameLength (default 4)
	// - the token display name, truncated to role.DisplaynameLength (default 4)
	// - a UUID
	//
	// The entire concatenated string is then truncated to role.UsernameLength,
	// which by default is 16 due to limitations in older but still-prevalent
	// versions of MySQL.
	roleName := name
	if len(roleName) > role.RolenameLength {
		roleName = roleName[:role.RolenameLength]
	}
	displayName := req.DisplayName
	if len(displayName) > role.DisplaynameLength {
		displayName = displayName[:role.DisplaynameLength]
	}
	userUUID, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	username := fmt.Sprintf("%s-%s-%s", roleName, displayName, userUUID)
	if len(username) > role.UsernameLength {
		username = username[:role.UsernameLength]
	}
	password, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}

	// Get our handle
	db, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	// Start a transaction
	tx, err := db.Begin()
	if err != nil {
		return nil, err
	}
	// Deferred unconditionally so any early return abandons the transaction;
	// the result is ignored, including after a successful commit.
	defer tx.Rollback()

	// Execute each query
	for _, query := range strutil.ParseArbitraryStringSlice(role.SQL, ";") {
		query = strings.TrimSpace(query)
		if len(query) == 0 {
			continue
		}

		stmt, err := tx.Prepare(Query(query, map[string]string{
			"name":     username,
			"password": password,
		}))
		if err != nil {
			return nil, err
		}
		_, err = stmt.Exec()
		// Close the statement as soon as it has run. A defer here would keep
		// every prepared statement open until the function returns.
		stmt.Close()
		if err != nil {
			return nil, err
		}
	}

	// Commit the transaction
	if err := tx.Commit(); err != nil {
		return nil, err
	}

	// Return the secret
	resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
		"username": username,
		"password": password,
	}, map[string]interface{}{
		"username": username,
	})
	resp.Secret.TTL = lease.Lease
	return resp, nil
}
예제 #6
0
// pathCredsCreateRead issues a new set of credentials for the role named by
// the "name" request field.
//
// The username has the form "vault_<role>_<display name>_<uuid>_<unix time>"
// with every "-" replaced by "_"; the password is a UUID. When the role
// specifies a consistency level it is parsed and applied to the session first.
// Each non-empty, semicolon-separated statement of the role's creation CQL is
// then executed with "username" and "password" substituted in.
//
// If a creation statement fails, the role's rollback CQL is run statement by
// statement on a best-effort basis (its errors are discarded) and the original
// failure is returned. An unknown role produces a logical error response with
// a nil error. On success the secret response carries the username and
// password, keeps the username and role name as internal data, and has its TTL
// set to the role's lease.
func (b *backend) pathCredsCreateRead(
	req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	roleName := data.Get("name").(string)

	// Look up the role definition.
	role, err := getRole(req.Storage, roleName)
	if err != nil {
		return nil, err
	}
	if role == nil {
		return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", roleName)), nil
	}

	// Build the credentials.
	id, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	username := strings.Replace(
		fmt.Sprintf("vault_%s_%s_%s_%d", roleName, req.DisplayName, id, time.Now().Unix()),
		"-", "_", -1)
	password, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}

	// Obtain the session.
	session, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	// Apply the role's consistency level, if it has one.
	if role.Consistency != "" {
		level, parseErr := gocql.ParseConsistencyWrapper(role.Consistency)
		if parseErr != nil {
			return nil, parseErr
		}
		session.SetConsistency(level)
	}

	// statements splits a CQL script on ";" and drops blank entries.
	statements := func(cql string) []string {
		var out []string
		for _, s := range strutil.ParseArbitraryStringSlice(cql, ";") {
			if s = strings.TrimSpace(s); s != "" {
				out = append(out, s)
			}
		}
		return out
	}

	// Run the creation statements, undoing on the first failure.
	for _, create := range statements(role.CreationCQL) {
		err = session.Query(substQuery(create, map[string]string{
			"username": username,
			"password": password,
		})).Exec()
		if err == nil {
			continue
		}

		// Best-effort rollback: results are intentionally not inspected.
		for _, undo := range statements(role.RollbackCQL) {
			session.Query(substQuery(undo, map[string]string{
				"username": username,
				"password": password,
			})).Exec()
		}
		return nil, err
	}

	// Assemble the secret.
	secretData := map[string]interface{}{
		"username": username,
		"password": password,
	}
	internalData := map[string]interface{}{
		"username": username,
		"role":     roleName,
	}
	resp := b.Secret(SecretCredsType).Response(secretData, internalData)
	resp.Secret.TTL = role.Lease

	return resp, nil
}
예제 #7
0
// secretCredsRevoke revokes the database user recorded in the secret's
// internal data.
//
// The username is read from req.Secret.InternalData["username"] and must be a
// string. If the internal data names a role and that role defines revocation
// SQL, that SQL is used; otherwise defaultRevocationSQL is used and a warning
// is attached to the response. Each non-empty, semicolon-separated statement
// is executed inside a single transaction with "{{name}}" replaced by the
// username.
//
// Missing or non-string internal data, and any storage or database failure,
// are returned as errors.
func (b *backend) secretCredsRevoke(
	req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	var resp *logical.Response

	// Get the username from the internal data
	usernameRaw, ok := req.Secret.InternalData["username"]
	if !ok {
		return nil, fmt.Errorf("secret is missing username internal data")
	}
	// Without this check a non-string value would leave username empty and
	// the revocation SQL would run against the wrong (empty) name.
	username, ok := usernameRaw.(string)
	if !ok {
		return nil, fmt.Errorf("secret username internal data is not a string")
	}

	// Get our connection
	db, err := b.DB(req.Storage)
	if err != nil {
		return nil, err
	}

	roleName := ""
	if roleNameRaw, ok := req.Secret.InternalData["role"]; ok {
		// Checked assertion: a malformed value is reported instead of panicking.
		roleName, ok = roleNameRaw.(string)
		if !ok {
			return nil, fmt.Errorf("secret role internal data is not a string")
		}
	}

	var role *roleEntry
	if roleName != "" {
		role, err = b.Role(req.Storage, roleName)
		if err != nil {
			return nil, err
		}
	}

	// Use a default SQL statement for revocation if one cannot be fetched from the role
	revocationSQL := defaultRevocationSQL

	if role != nil && role.RevocationSQL != "" {
		revocationSQL = role.RevocationSQL
	} else {
		// This branch is also taken when the role exists but defines no
		// revocation SQL, and when the secret names no role at all.
		resp = &logical.Response{}
		resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default SQL for revoking user.", roleName))
	}

	// Start a transaction
	tx, err := db.Begin()
	if err != nil {
		return nil, err
	}
	// Deferred unconditionally so any early return abandons the transaction;
	// the result is ignored, including after a successful commit.
	defer tx.Rollback()

	for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
		query = strings.TrimSpace(query)
		if len(query) == 0 {
			continue
		}

		// This is not a prepared statement because not all commands are supported
		// 1295: This command is not supported in the prepared statement protocol yet
		// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
		query = strings.Replace(query, "{{name}}", username, -1)
		if _, err = tx.Exec(query); err != nil {
			return nil, err
		}
	}

	// Commit the transaction
	if err := tx.Commit(); err != nil {
		return nil, err
	}

	return resp, nil
}