예제 #1
0
// getTableNames implements the SchemaAccessor interface. It returns the
// database-qualified names of the tables in dbDesc. Virtual schemas are served
// from the session's virtual schema entries; for every other database the
// name metadata keys under the database's ID are scanned and decoded.
func (p *planner) getTableNames(dbDesc *sqlbase.DatabaseDescriptor) (parser.TableNames, error) {
	if entry, isVirtual := p.session.virtualSchemas.getVirtualSchemaEntry(dbDesc.Name); isVirtual {
		return entry.tableNames(), nil
	}

	nameKeyPrefix := sqlbase.MakeNameMetadataKey(dbDesc.ID, "")
	rows, err := p.txn.Scan(nameKeyPrefix, nameKeyPrefix.PrefixEnd(), 0)
	if err != nil {
		return nil, err
	}

	var names parser.TableNames
	for _, kv := range rows {
		// The table name is decoded from the key bytes that follow the prefix.
		encodedName := bytes.TrimPrefix(kv.Key, nameKeyPrefix)
		_, decoded, decodeErr := encoding.DecodeUnsafeStringAscending(encodedName, nil)
		if decodeErr != nil {
			return nil, decodeErr
		}
		names = append(names, parser.TableName{
			DatabaseName: parser.Name(dbDesc.Name),
			TableName:    parser.Name(decoded),
		})
	}
	return names, nil
}
예제 #2
0
// userTablesAndDBsMatchingName filters descs down to the database and table
// descriptors selected by name. A "*" in either component of name matches
// everything. The system database is never matched, and a table is only
// considered when its parent database matched. Matching databases are listed
// before matching tables. The returned error is always nil.
func userTablesAndDBsMatchingName(
	descs []sqlbase.Descriptor, name parser.TableName,
) ([]sqlbase.Descriptor, error) {
	const wildcard = "*"
	wantTable := name.TableName.Normalize()
	wantDB := name.DatabaseName.Normalize()

	selected := make([]sqlbase.Descriptor, 0, len(descs))

	// First pass: collect the matching user databases, remembering their IDs
	// so that the second pass can tell which tables belong to them.
	matchedDBs := make(map[sqlbase.ID]string)
	for _, d := range descs {
		db := d.GetDatabase()
		if db == nil || db.ID == keys.SystemDatabaseID {
			// Not a database, or not a user database.
			continue
		}
		normalized := parser.Name(db.Name).Normalize()
		if wantDB != wildcard && normalized != wantDB {
			continue
		}
		selected = append(selected, d)
		matchedDBs[db.ID] = normalized
	}

	// Second pass: collect the matching tables of the databases found above.
	for _, d := range descs {
		tbl := d.GetTable()
		if tbl == nil {
			continue
		}
		if _, parentMatched := matchedDBs[tbl.ParentID]; !parentMatched {
			continue
		}
		if wantTable == wildcard || parser.Name(tbl.Name).Normalize() == wantTable {
			selected = append(selected, d)
		}
	}
	return selected, nil
}
예제 #3
0
파일: backup.go 프로젝트: hvaara/cockroach
// runRestore restores the table selected by backupCtx from the base path
// given as the single positional argument, printing the name of every
// restored table followed by the base path that was used.
func runRestore(cmd *cobra.Command, args []string) error {
	if len(args) != 1 {
		return errors.New("input basepath argument is required")
	}
	basePath := args[0]

	kvDB, stopper, err := makeDBClient()
	if err != nil {
		return err
	}
	defer stopper.Stop()

	target := parser.TableName{
		TableName:    parser.Name(backupCtx.table),
		DatabaseName: parser.Name(backupCtx.database),
	}
	restoredTables, err := sql.Restore(context.Background(), *kvDB, basePath, target)
	if err != nil {
		return err
	}

	for _, restoredTable := range restoredTables {
		fmt.Printf("Restored table %q\n", restoredTable.Name)
	}
	fmt.Printf("Restored from %s\n", basePath)
	return nil
}
예제 #4
0
파일: upsert.go 프로젝트: knz/cockroach
// upsertExprsAndIndex returns the upsert conflict index and the (possibly
// synthetic) SET expressions used when a row conflicts.
//
// For the UPSERT alias form, the conflict index is the primary index and one
// SET expression is synthesized for every column of insertCols that is not
// part of that index. Otherwise the conflict index is the first unique index
// whose column names match onConflict.Columns position by position (the
// primary index is checked before the secondary indexes), and
// onConflict.Exprs is returned unchanged. An error is returned when no unique
// index matches.
//
// The returned index always points into tableDesc rather than at a copy.
func upsertExprsAndIndex(
	tableDesc *sqlbase.TableDescriptor,
	onConflict parser.OnConflict,
	insertCols []sqlbase.ColumnDescriptor,
) (parser.UpdateExprs, *sqlbase.IndexDescriptor, error) {
	if onConflict.IsUpsertAlias() {
		// The UPSERT syntactic sugar is the same as the longhand specifying the
		// primary index as the conflict index and SET expressions for the columns
		// in insertCols minus any columns in the conflict index. Example:
		// `UPSERT INTO abc VALUES (1, 2, 3)` is syntactic sugar for
		// `INSERT INTO abc VALUES (1, 2, 3) ON CONFLICT a DO UPDATE SET b = 2, c = 3`.
		conflictIndex := &tableDesc.PrimaryIndex
		indexColSet := make(map[sqlbase.ColumnID]struct{}, len(conflictIndex.ColumnIDs))
		for _, colID := range conflictIndex.ColumnIDs {
			indexColSet[colID] = struct{}{}
		}
		updateExprs := make(parser.UpdateExprs, 0, len(insertCols))
		for _, c := range insertCols {
			if _, ok := indexColSet[c.ID]; ok {
				// Columns of the conflict index are not updated.
				continue
			}
			names := parser.UnresolvedNames{
				parser.UnresolvedName{parser.Name(c.Name)},
			}
			// Each synthesized expression refers to the same-named column of
			// upsertExcludedTable.
			expr := &parser.ColumnItem{
				TableName:  upsertExcludedTable,
				ColumnName: parser.Name(c.Name),
			}
			updateExprs = append(updateExprs, &parser.UpdateExpr{Names: names, Expr: expr})
		}
		return updateExprs, conflictIndex, nil
	}

	// indexMatch reports whether index is unique and its columns are exactly
	// the ON CONFLICT columns, in the same order. It takes a pointer so that
	// descriptors are not copied for every candidate.
	indexMatch := func(index *sqlbase.IndexDescriptor) bool {
		if !index.Unique {
			return false
		}
		if len(index.ColumnNames) != len(onConflict.Columns) {
			return false
		}
		for i, colName := range index.ColumnNames {
			if parser.ReNormalizeName(colName) != onConflict.Columns[i].Normalize() {
				return false
			}
		}
		return true
	}

	if indexMatch(&tableDesc.PrimaryIndex) {
		return onConflict.Exprs, &tableDesc.PrimaryIndex, nil
	}
	// Iterate by position so that the returned pointer refers to the element
	// stored in tableDesc.Indexes and not to a per-iteration copy of it.
	for i := range tableDesc.Indexes {
		if indexMatch(&tableDesc.Indexes[i]) {
			return onConflict.Exprs, &tableDesc.Indexes[i], nil
		}
	}
	return nil, nil, fmt.Errorf("there is no unique or exclusion constraint matching the ON CONFLICT specification")
}
예제 #5
0
// tableNames returns one parser.TableName per entry of e.orderedTableNames,
// in that order, each qualified with the name of the schema's database
// descriptor. The result is nil when there are no table names.
func (e virtualSchemaEntry) tableNames() parser.TableNames {
	var qualified parser.TableNames
	for _, name := range e.orderedTableNames {
		qualified = append(qualified, parser.TableName{
			TableName:    parser.Name(name),
			DatabaseName: parser.Name(e.desc.Name),
		})
	}
	return qualified
}
예제 #6
0
// getQualifiedTableName returns the database-qualified name of the table
// or view represented by desc. The parent database is looked up by
// desc.ParentID within the planner's transaction; a lookup failure is
// returned as-is together with an empty string.
func (p *planner) getQualifiedTableName(desc *sqlbase.TableDescriptor) (string, error) {
	parentDB, err := sqlbase.GetDatabaseDescFromID(p.txn, desc.ParentID)
	if err != nil {
		return "", err
	}

	qualified := parser.TableName{
		TableName:    parser.Name(desc.Name),
		DatabaseName: parser.Name(parentDB.Name),
	}
	return qualified.String(), nil
}
예제 #7
0
// testInitDummySelectNode builds a renderNode for tests whose single source
// is a scanNode over a copy of desc. The source is registered under the table
// name desc.Name in a database called "test".
func testInitDummySelectNode(desc *sqlbase.TableDescriptor) *renderNode {
	tp := makeTestPlanner()

	tableScan := &scanNode{p: tp}
	tableScan.desc = *desc
	tableScan.initDescDefaults(publicColumns)

	render := &renderNode{planner: tp}
	render.source.plan = tableScan
	sourceName := parser.TableName{
		DatabaseName: parser.Name("test"),
		TableName:    parser.Name(desc.Name),
	}
	render.source.info = newSourceInfoForSingleTable(sourceName, tableScan.Columns())
	render.sourceInfo = multiSourceInfo{render.source.info}
	// The indexed-variable helper is sized to the number of scan columns.
	render.ivarHelper = parser.MakeIndexedVarHelper(render, len(tableScan.Columns()))

	return render
}
예제 #8
0
// queryNamespace looks up, in system.namespace, the ID of the entry called
// name whose parent is parentID. name is normalized before it is compared.
//
// Errors from running the query and from fetching the first row are returned
// unchanged, so callers see whatever rows.Next reports when the query yields
// no row. An error is also returned when the result does not consist of
// exactly one int64 column.
func queryNamespace(conn *sqlConn, parentID sqlbase.ID, name string) (sqlbase.ID, error) {
	rows, err := makeQuery(
		`SELECT id FROM system.namespace WHERE parentID = $1 AND name = $2`,
		parentID, parser.Name(name).Normalize())(conn)
	if err != nil {
		return 0, err
	}
	// The error from Close is deliberately discarded: by then the result (or
	// a more relevant error) has already been determined.
	defer func() { _ = rows.Close() }()

	if numCols := len(rows.Columns()); numCols != 1 {
		return 0, fmt.Errorf("unexpected result columns: %d", numCols)
	}
	vals := make([]driver.Value, 1)
	if err := rows.Next(vals); err != nil {
		return 0, err
	}
	id, ok := vals[0].(int64)
	if !ok {
		return 0, fmt.Errorf("unexpected result type: %T", vals[0])
	}
	return sqlbase.ID(id), nil
}
예제 #9
0
파일: user.go 프로젝트: knz/cockroach
// GetUserHashedPassword returns the hashedPassword for the given username if
// found in system.users.
//
// The username is normalized before the lookup. The root user is not stored
// in system.users, so (nil, nil) is returned for it. An error is returned
// when the lookup fails, when the user does not exist, or when the stored
// value is not a bytes datum.
func GetUserHashedPassword(
	ctx context.Context, executor *Executor, metrics *MemoryMetrics, username string,
) ([]byte, error) {
	normalizedUsername := parser.Name(username).Normalize()
	// The root user is not in system.users.
	if normalizedUsername == security.RootUser {
		return nil, nil
	}

	var hashedPassword []byte
	if err := executor.cfg.DB.Txn(ctx, func(txn *client.Txn) error {
		// The lookup runs as the root user through an internal planner.
		p := makeInternalPlanner("get-pwd", txn, security.RootUser, metrics)
		defer finishInternalPlanner(p)
		const getHashedPassword = `SELECT hashedPassword FROM system.users ` +
			`WHERE username=$1`
		values, err := p.queryRow(getHashedPassword, normalizedUsername)
		if err != nil {
			// NOTE(review): the underlying error is not included in the
			// message — confirm whether that is intentional.
			return errors.Errorf("error looking up user %s", normalizedUsername)
		}
		if len(values) == 0 {
			return errors.Errorf("user %s does not exist", normalizedUsername)
		}
		// Use the checked form of the type assertion so that an unexpected
		// datum (for example a NULL password) is reported as an error instead
		// of panicking.
		stored, ok := values[0].(*parser.DBytes)
		if !ok {
			return errors.Errorf("unexpected hashed password type %T for user %s",
				values[0], normalizedUsername)
		}
		hashedPassword = []byte(*stored)
		return nil
	}); err != nil {
		return nil, err
	}

	return hashedPassword, nil
}
예제 #10
0
// quoteNames renders names as a single string by formatting them as a
// parser.NameList with parser.AsString. Per the original author's note, the
// names are quoted based on Traditional syntax with commas between them.
func quoteNames(names ...string) string {
	list := make(parser.NameList, 0, len(names))
	for _, name := range names {
		list = append(list, parser.Name(name))
	}
	return parser.AsString(list)
}
예제 #11
0
// extractFunc returns an aggregateFuncHolder for a given SelectExpr specifying
// an aggregation function. AggregatorSpec_IDENT maps directly to the parser's
// identity aggregate; every other function is resolved by name through the
// parser. When expr.Distinct is set, the holder is given a set for tracking
// the values it has seen.
func (ag *aggregator) extractFunc(
	expr AggregatorSpec_Expr, eh *exprHelper,
) (*aggregateFuncHolder, error) {
	if expr.Func == AggregatorSpec_IDENT {
		return ag.newAggregateFuncHolder(parser.NewIdentAggregate), nil
	}

	// To reuse the aggregate functions defined in the parser package, this
	// relies on each name in the AggregatorSpec Func enum matching a SQL
	// function name known to the parser. See
	// pkg/sql/parser/aggregate_builtins.go for the builtins being repurposed.
	call := &parser.FuncExpr{
		Func: parser.ResolvableFunctionReference{
			FunctionReference: parser.UnresolvedName{parser.Name(expr.Func.String())},
		},
		Exprs: []parser.Expr{eh.vars.IndexedVar(int(expr.ColIdx))},
	}
	if _, err := call.TypeCheck(nil, parser.NoTypePreference); err != nil {
		return nil, err
	}

	constructor := call.GetAggregateConstructor()
	if constructor == nil {
		return nil, errors.Errorf("unable to get aggregate constructor for %s",
			AggregatorSpec_Func_name[int32(expr.Func)])
	}

	holder := ag.newAggregateFuncHolder(constructor)
	if expr.Distinct {
		holder.seen = make(map[string]struct{})
	}
	return holder, nil
}
예제 #12
0
// newReturningHelper creates a new returningHelper for use by an
// insert/update node.
//
// r holds the RETURNING expressions, desiredTypes the optional desired type
// of each expression (by position), alias the name under which the table's
// columns are visible, and tablecols those columns. With no RETURNING
// expressions the helper is returned with only its planner set. Aggregation
// and window functions are rejected.
func (p *planner) newReturningHelper(
	r parser.ReturningExprs,
	desiredTypes []parser.Type,
	alias string,
	tablecols []sqlbase.ColumnDescriptor,
) (*returningHelper, error) {
	helper := &returningHelper{p: p}
	if len(r) == 0 {
		return helper, nil
	}

	for _, retExpr := range r {
		err := p.parser.AssertNoAggregationOrWindowing(
			retExpr.Expr, "RETURNING", p.session.SearchPath,
		)
		if err != nil {
			return nil, err
		}
	}

	helper.columns = make(ResultColumns, 0, len(r))
	sourceName := parser.TableName{TableName: parser.Name(alias)}
	helper.source = newSourceInfoForSingleTable(sourceName, makeResultColumns(tablecols))
	helper.exprs = make([]parser.TypedExpr, 0, len(r))
	ivars := parser.MakeIndexedVarHelper(helper, len(tablecols))

	for idx, target := range r {
		// Pre-normalize VarNames at the top level so that checkRenderStar can see stars.
		if err := target.NormalizeTopLevelVarName(); err != nil {
			return nil, err
		}

		isStar, starCols, starExprs, err := checkRenderStar(target, helper.source, ivars)
		if err != nil {
			return nil, err
		}
		if isStar {
			// A star contributes all of its expanded expressions and columns.
			helper.exprs = append(helper.exprs, starExprs...)
			helper.columns = append(helper.columns, starCols...)
			continue
		}

		// The output column name must exactly match the original expression,
		// so it is captured before the expression is manipulated in any way.
		colName := getRenderColName(target)

		wantType := parser.TypeAny
		if idx < len(desiredTypes) {
			wantType = desiredTypes[idx]
		}

		typed, err := helper.p.analyzeExpr(
			target.Expr, multiSourceInfo{helper.source}, ivars, wantType, false, "")
		if err != nil {
			return nil, err
		}
		helper.exprs = append(helper.exprs, typed)
		helper.columns = append(helper.columns, ResultColumn{Name: colName, Typ: typed.ResolvedType()})
	}
	return helper, nil
}
예제 #13
0
파일: ordering.go 프로젝트: knz/cockroach
// Format pretty-prints the orderingInfo to a stream.
// If columns is not nil, column names are printed instead of column indexes.
//
// The output is a comma-separated list: exact-match columns prefixed with
// '=', then ordering columns prefixed with '+' or '-' (for descending), then
// the word "unique" when the ordering is unique.
func (ord orderingInfo) Format(buf *bytes.Buffer, columns ResultColumns) {
	// writeSep emits the separator before every item except the first.
	needSep := false
	writeSep := func() {
		if needSep {
			buf.WriteString(",")
		}
		needSep = true
	}
	// writeCol prints the column's name when it is known, otherwise its index.
	writeCol := func(idx int) {
		if columns != nil && idx < len(columns) {
			parser.Name(columns[idx].Name).Format(buf, parser.FmtSimple)
			return
		}
		fmt.Fprintf(buf, "%d", idx)
	}

	// The exact-match columns are sorted to make the output deterministic.
	exact := make([]int, 0, len(ord.exactMatchCols))
	for col := range ord.exactMatchCols {
		exact = append(exact, col)
	}
	sort.Ints(exact)
	for _, col := range exact {
		writeSep()
		buf.WriteByte('=')
		writeCol(col)
	}

	// The ordering columns follow, each with its sort direction.
	for _, o := range ord.ordering {
		writeSep()
		if o.Direction == encoding.Descending {
			buf.WriteByte('-')
		} else {
			buf.WriteByte('+')
		}
		writeCol(o.ColIdx)
	}

	if ord.unique {
		if needSep {
			buf.WriteString(",")
		}
		buf.WriteString("unique")
	}
}
예제 #14
0
// indexDefFromDescriptor creates an index definition (`CREATE INDEX ... ON (...)`)
// from an index descriptor by reconstructing a CreateIndex parser node and
// calling its String method. When the index has interleave ancestors, the
// table of the last ancestor is looked up through p's transaction and an
// interleave clause is added to the definition.
func indexDefFromDescriptor(
	p *planner,
	db *sqlbase.DatabaseDescriptor,
	table *sqlbase.TableDescriptor,
	index *sqlbase.IndexDescriptor,
) (string, error) {
	def := parser.CreateIndex{
		Name: parser.Name(index.Name),
		Table: parser.NormalizableTableName{
			TableNameReference: &parser.TableName{
				DatabaseName: parser.Name(db.Name),
				TableName:    parser.Name(table.Name),
			},
		},
		Unique:  index.Unique,
		Columns: make(parser.IndexElemList, len(index.ColumnNames)),
		Storing: make(parser.NameList, len(index.StoreColumnNames)),
	}

	for i, colName := range index.ColumnNames {
		direction := parser.Ascending
		if index.ColumnDirections[i] == sqlbase.IndexDescriptor_DESC {
			direction = parser.Descending
		}
		def.Columns[i] = parser.IndexElem{
			Column:    parser.Name(colName),
			Direction: direction,
		}
	}
	for i, storedName := range index.StoreColumnNames {
		def.Storing[i] = parser.Name(storedName)
	}

	if ancestors := index.Interleave.Ancestors; len(ancestors) > 0 {
		lastAncestor := ancestors[len(ancestors)-1]
		parentTable, err := sqlbase.GetTableDescFromID(p.txn, lastAncestor.TableID)
		if err != nil {
			return "", err
		}

		// The interleave fields are the leading index columns; their count is
		// the sum of the shared prefix lengths of all ancestors.
		prefixLen := 0
		for _, ancestor := range ancestors {
			prefixLen += int(ancestor.SharedPrefixLen)
		}
		sharedCols := index.ColumnNames[:prefixLen]

		interleave := &parser.InterleaveDef{
			Parent: parser.NormalizableTableName{
				TableNameReference: &parser.TableName{
					TableName: parser.Name(parentTable.Name),
				},
			},
			Fields: make(parser.NameList, len(sharedCols)),
		}
		for i, colName := range sharedCols {
			interleave.Fields[i] = parser.Name(colName)
		}
		def.Interleave = interleave
	}
	return def.String(), nil
}
예제 #15
0
파일: internal.go 프로젝트: knz/cockroach
// GetTableSpan gets the key span for a SQL table, including any indices.
// The table is resolved to its ID with an internal planner running as user
// inside txn; the span runs from the table's key prefix to the end of that
// prefix.
func (ie InternalExecutor) GetTableSpan(
	user string, txn *client.Txn, dbName, tableName string,
) (roachpb.Span, error) {
	// Resolve the table name to its ID.
	ip := makeInternalPlanner("get-table-span", txn, user, ie.LeaseManager.memMetrics)
	defer finishInternalPlanner(ip)
	ip.leaseMgr = ie.LeaseManager

	qualified := &parser.TableName{
		TableName:    parser.Name(tableName),
		DatabaseName: parser.Name(dbName),
	}
	tableID, err := getTableID(ip, qualified)
	if err != nil {
		return roachpb.Span{}, err
	}

	// Every key of the table starts with the table prefix.
	startKey := roachpb.Key(keys.MakeTablePrefix(uint32(tableID)))
	return roachpb.Span{Key: startKey, EndKey: startKey.PrefixEnd()}, nil
}
예제 #16
0
// getGeneratorPlan builds a planDataSource from the generator produced by
// p.makeGenerator for the function expression t. The source is registered
// under the name returned by makeGenerator and exposes the generator plan's
// columns.
func (p *planner) getGeneratorPlan(t *parser.FuncExpr) (planDataSource, error) {
	generator, genName, err := p.makeGenerator(t)
	if err != nil {
		return planDataSource{}, err
	}

	sourceName := parser.TableName{TableName: parser.Name(genName)}
	ds := planDataSource{
		plan: generator,
		info: newSourceInfoForSingleTable(sourceName, generator.Columns()),
	}
	return ds, nil
}
예제 #17
0
파일: v3.go 프로젝트: BramGruneir/cockroach
// handleAuthentication should discuss with the client to arrange
// authentication and update c.sessionArgs with the authenticated user's
// name, if different from the one given initially. Note: at this
// point the sql.Session does not exist yet! If need exists to access the
// database to look up authentication data, use the internal executor.
//
// Authentication is only performed when the underlying connection is a TLS
// connection: with client certificates the certificate hook is used,
// otherwise the client is asked for a password. For any other connection type
// the checks are skipped. On success an authOK message is written to the
// client; every failure is reported through c.sendInternalError.
func (c *v3Conn) handleAuthentication(ctx context.Context, insecure bool) error {
	if tlsConn, ok := c.conn.(*tls.Conn); ok {
		var authenticationHook security.UserAuthHook

		// Check that the requested user exists and retrieve the hashed
		// password in case password authentication is needed.
		hashedPassword, err := sql.GetUserHashedPassword(
			ctx, c.executor, c.metrics.internalMemMetrics, c.sessionArgs.User,
		)
		if err != nil {
			return c.sendInternalError(err.Error())
		}

		tlsState := tlsConn.ConnectionState()
		// If no certificates are provided, default to password
		// authentication.
		if len(tlsState.PeerCertificates) == 0 {
			// Ask the client for its password; it is checked against
			// hashedPassword by the hook below.
			password, err := c.sendAuthPasswordRequest()
			if err != nil {
				return c.sendInternalError(err.Error())
			}
			authenticationHook = security.UserAuthPasswordHook(
				insecure, password, hashedPassword,
			)
		} else {
			// Normalize the username contained in the certificate.
			// Only the first peer certificate's common name is rewritten.
			tlsState.PeerCertificates[0].Subject.CommonName = parser.Name(
				tlsState.PeerCertificates[0].Subject.CommonName,
			).Normalize()
			var err error
			authenticationHook, err = security.UserAuthCertHook(insecure, &tlsState)
			if err != nil {
				return c.sendInternalError(err.Error())
			}
		}

		// Run whichever hook was selected against the requested user.
		if err := authenticationHook(c.sessionArgs.User, true /* public */); err != nil {
			return c.sendInternalError(err.Error())
		}
	}

	// Tell the client that authentication succeeded.
	c.writeBuf.initMsg(serverMsgAuth)
	c.writeBuf.putInt32(authOK)
	return c.writeBuf.finishMsg(c.wr)
}
예제 #18
0
파일: join.go 프로젝트: maxlang/cockroach
// commonColumns returns the names of columns common on the
// right and left sides, for use by NATURAL JOIN. Hidden columns on either
// side are ignored, names are compared after re-normalization, and the left
// side's spelling of each common name is what gets returned.
func commonColumns(left, right *dataSourceInfo) parser.NameList {
	var shared parser.NameList
	for _, l := range left.sourceColumns {
		if l.hidden {
			continue
		}
		for _, r := range right.sourceColumns {
			if r.hidden {
				continue
			}
			if parser.ReNormalizeName(r.Name) != parser.ReNormalizeName(l.Name) {
				continue
			}
			shared = append(shared, parser.Name(l.Name))
		}
	}
	return shared
}
예제 #19
0
// ExplainPlan describes the node for EXPLAIN. The children are the source
// plan followed by the subquery plans collected from the filter and the
// render expressions. When there is nothing but the source plan and v is
// false, the source plan's own explanation is returned instead. Otherwise the
// description lists the source columns as "alias.column" (with '_' standing
// in for a missing alias and '*' marking hidden columns).
func (s *selectNode) ExplainPlan(v bool) (name, description string, children []planNode) {
	children = s.planner.collectSubqueryPlans(s.filter, []planNode{s.source.plan})
	for _, renderExpr := range s.render {
		children = s.planner.collectSubqueryPlans(renderExpr, children)
	}

	if !v && len(children) == 1 {
		return s.source.plan.ExplainPlan(v)
	}

	var desc bytes.Buffer
	desc.WriteString("from (")
	for i, col := range s.source.info.sourceColumns {
		if i > 0 {
			desc.WriteString(", ")
		}
		if col.hidden {
			desc.WriteByte('*')
		}
		if alias, ok := s.source.info.findTableAlias(i); ok {
			parser.FormatNode(&desc, parser.FmtSimple, &alias)
		} else {
			desc.WriteByte('_')
		}
		desc.WriteByte('.')
		parser.FormatNode(&desc, parser.FmtSimple, parser.Name(col.Name))
	}
	desc.WriteByte(')')

	name = "render/filter"
	if s.explain != explainNone {
		name = fmt.Sprintf("%s(%s)", name, explainStrings[s.explain])
	}
	description = desc.String()
	return name, description, children
}
예제 #20
0
// databaseFromSearchPath returns the first database in the session's SearchPath
// that contains the specified table. If the table can't be found, we return the
// session database. tn itself is not modified: the lookups are made on a copy.
func (p *planner) databaseFromSearchPath(tn *parser.TableName) (string, error) {
	candidate := *tn
	for _, dbName := range p.session.SearchPath {
		candidate.DatabaseName = parser.Name(dbName)
		desc, err := p.getTableOrViewDesc(&candidate)
		if err != nil {
			if _, missingDB := err.(*sqlbase.ErrUndefinedDatabase); !missingDB {
				return "", err
			}
			// A database in the search path that doesn't exist is skipped.
			continue
		}
		if desc == nil {
			continue
		}
		// The table or view exists in this database.
		return candidate.Database(), nil
	}
	// The table or view is in none of the search path databases, so default
	// to the database set by the user.
	return p.session.Database, nil
}
예제 #21
0
// getVirtualDataSource attempts to find a virtual table with the
// given name. The boolean result reports whether one was found; when it is
// true, the returned data source wraps a delayedNode built from the virtual
// table's plan information.
func (p *planner) getVirtualDataSource(tn *parser.TableName) (planDataSource, bool, error) {
	entry, err := p.session.virtualSchemas.getVirtualTableEntry(tn)
	if err != nil {
		return planDataSource{}, false, err
	}
	if entry.desc == nil {
		// No virtual table goes by this name.
		return planDataSource{}, false, nil
	}

	cols, construct := entry.getPlanInfo()
	// The source keeps the database name given by the caller and takes the
	// table name from the descriptor.
	name := parser.TableName{
		DatabaseName: tn.DatabaseName,
		TableName:    parser.Name(entry.desc.Name),
	}
	ds := planDataSource{
		info: newSourceInfoForSingleTable(name, cols),
		plan: &delayedNode{
			name:        name.String(),
			columns:     cols,
			constructor: construct,
		},
	}
	return ds, true, nil
}
예제 #22
0
// formatColumns converts a column signature for a data source /
// planNode to a string. The column types are printed iff the 2nd
// argument specifies so.
//
// The result is a parenthesized, ", "-separated list. Each column name may be
// followed by its properties in brackets (e.g. "[hidden,omitted]") and, when
// printTypes is set, by a space and the column type.
func formatColumns(cols ResultColumns, printTypes bool) string {
	var out bytes.Buffer
	out.WriteByte('(')
	for i, col := range cols {
		if i > 0 {
			out.WriteString(", ")
		}
		parser.Name(col.Name).Format(&out, parser.FmtSimple)

		// Gather the extra properties first, then print them as one group.
		var props []string
		if col.hidden {
			props = append(props, "hidden")
		}
		if col.omitted {
			props = append(props, "omitted")
		}
		if len(props) > 0 {
			out.WriteByte('[')
			for j, prop := range props {
				if j > 0 {
					out.WriteByte(',')
				}
				out.WriteString(prop)
			}
			out.WriteByte(']')
		}

		if printTypes {
			out.WriteByte(' ')
			out.WriteString(col.Typ.String())
		}
	}
	out.WriteByte(')')
	return out.String()
}
예제 #23
0
// newReturningHelper creates a new returningHelper for use by an
// insert/update node.
//
// r holds the RETURNING expressions, alias the name under which the table's
// columns are visible, and tablecols those columns. With no RETURNING
// expressions the helper is returned with only its planner set. Aggregation
// and window functions are rejected. Each expression is rendered with
// parser.TypeAny as its desired type; desiredTypes is not consulted here.
func (p *planner) newReturningHelper(
	r parser.ReturningExprs,
	desiredTypes []parser.Type,
	alias string,
	tablecols []sqlbase.ColumnDescriptor,
) (*returningHelper, error) {
	helper := &returningHelper{p: p}
	if len(r) == 0 {
		return helper, nil
	}

	for _, retExpr := range r {
		err := p.parser.AssertNoAggregationOrWindowing(
			retExpr.Expr, "RETURNING", p.session.SearchPath,
		)
		if err != nil {
			return nil, err
		}
	}

	helper.columns = make(ResultColumns, 0, len(r))
	sourceName := parser.TableName{TableName: parser.Name(alias)}
	helper.source = newSourceInfoForSingleTable(sourceName, makeResultColumns(tablecols))
	helper.exprs = make([]parser.TypedExpr, 0, len(r))
	ivars := parser.MakeIndexedVarHelper(helper, len(tablecols))

	for _, target := range r {
		renderCols, renderExprs, _, err := p.computeRender(
			target, parser.TypeAny, helper.source, ivars, true)
		if err != nil {
			return nil, err
		}
		helper.columns = append(helper.columns, renderCols...)
		helper.exprs = append(helper.exprs, renderExprs...)
	}
	return helper, nil
}
예제 #24
0
파일: dump.go 프로젝트: hvaara/cockroach
// dumpTable writes to w the SQL needed to recreate the table origTableName of
// database origDBName: its SHOW CREATE TABLE output followed by INSERT
// statements for its rows.
//
// The table metadata and the cluster timestamp are fetched inside one
// transaction. The rows are then read in pages of `limit` rows, ordered by
// the columns of the first index reported by SHOW INDEX, using
// AS OF SYSTEM TIME with that timestamp.
//
// NOTE(review): several early-return paths between conn.Query and
// rows.Close leave the rows open, and an error after BEGIN returns without
// ending the transaction — confirm whether the caller discards conn in that
// case.
func dumpTable(w io.Writer, conn *sqlConn, origDBName, origTableName string) error {
	// Number of rows fetched per SELECT and written per INSERT statement.
	const limit = 100

	// Escape names since they can't be used in placeholders.
	dbname := parser.Name(origDBName).String()
	tablename := parser.Name(origTableName).String()

	if err := conn.Exec(fmt.Sprintf("SET DATABASE = %s", dbname), nil); err != nil {
		return err
	}

	// Fetch all table metadata in a transaction and its time to guarantee it
	// doesn't change between the various SHOW statements.
	if err := conn.Exec("BEGIN", nil); err != nil {
		return err
	}

	vals, err := conn.QueryRow("SELECT cluster_logical_timestamp()", nil)
	if err != nil {
		return err
	}
	// NOTE(review): unchecked type assertion; panics if the value is not []byte.
	clusterTS := string(vals[0].([]byte))

	// A previous version of the code did a SELECT on system.descriptor. This
	// required the SELECT privilege to the descriptor table, which only root
	// has. Allowing non-root to do this would let users see other users' table
	// descriptors which is a problem in multi-tenancy.

	// Fetch column types.
	rows, err := conn.Query(fmt.Sprintf("SHOW COLUMNS FROM %s", tablename), nil)
	if err != nil {
		return err
	}
	vals = make([]driver.Value, 2)
	// coltypes maps a column name to its type as reported by SHOW COLUMNS.
	coltypes := make(map[string]string)
	for {
		if err := rows.Next(vals); err == io.EOF {
			break
		} else if err != nil {
			return err
		}
		// The first two result columns are read as the column name and type.
		nameI, typI := vals[0], vals[1]
		name, ok := nameI.(string)
		if !ok {
			return fmt.Errorf("unexpected value: %T", nameI)
		}
		typ, ok := typI.(string)
		if !ok {
			return fmt.Errorf("unexpected value: %T", typI)
		}
		coltypes[name] = typ
	}
	if err := rows.Close(); err != nil {
		return err
	}

	// index holds the names, in order, of the primary key columns.
	var index []string
	// Primary index is always the first index returned by SHOW INDEX.
	rows, err = conn.Query(fmt.Sprintf("SHOW INDEX FROM %s", tablename), nil)
	if err != nil {
		return err
	}
	vals = make([]driver.Value, 5)
	var primaryIndex string
	// Find the primary index columns.
	for {
		if err := rows.Next(vals); err == io.EOF {
			break
		} else if err != nil {
			return err
		}
		// vals[1] is treated as the index name of the current row.
		b, ok := vals[1].(string)
		if !ok {
			return fmt.Errorf("unexpected value: %T", vals[1])
		}
		if primaryIndex == "" {
			primaryIndex = b
		} else if primaryIndex != b {
			// A different index name means the first index's rows are over.
			break
		}
		// vals[4] is treated as the column name; it is escaped before use.
		b, ok = vals[4].(string)
		if !ok {
			return fmt.Errorf("unexpected value: %T", vals[4])
		}
		index = append(index, parser.Name(b).String())
	}
	if err := rows.Close(); err != nil {
		return err
	}
	if len(index) == 0 {
		return fmt.Errorf("no primary key index found")
	}
	indexes := strings.Join(index, ", ")

	// Build the SELECT query.
	// The key columns are selected first, ahead of "*", so that each row
	// starts with the values needed to resume after it.
	// NOTE(review): primaryIndex is interpolated without the parser.Name
	// escaping applied to the table and column names — confirm that is safe.
	var sbuf bytes.Buffer
	fmt.Fprintf(&sbuf, "SELECT %s, * FROM %s@%s AS OF SYSTEM TIME %s", indexes, tablename, primaryIndex, clusterTS)

	// wbuf holds the WHERE clause used by every query after the first: it
	// restricts the rows to those whose key is greater than the placeholders
	// $1..$n, which are bound to the key values in pk.
	var wbuf bytes.Buffer
	fmt.Fprintf(&wbuf, " WHERE ROW (%s) > ROW (", indexes)
	for i := range index {
		if i > 0 {
			wbuf.WriteString(", ")
		}
		fmt.Fprintf(&wbuf, "$%d", i+1)
	}
	wbuf.WriteString(")")
	// No WHERE clause first time, so add a place to inject it.
	// The "%%s" leaves a literal "%s" in bs that is filled in further down.
	fmt.Fprintf(&sbuf, "%%s ORDER BY %s LIMIT %d", indexes, limit)
	bs := sbuf.String()

	vals, err = conn.QueryRow(fmt.Sprintf("SHOW CREATE TABLE %s", tablename), nil)
	if err != nil {
		return err
	}
	// The second result column is written out as the CREATE statement.
	// NOTE(review): unchecked type assertion; panics if it is not a string.
	create := vals[1].(string)
	if _, err := w.Write([]byte(create)); err != nil {
		return err
	}
	if _, err := w.Write([]byte(";\n")); err != nil {
		return err
	}

	if err := conn.Exec("COMMIT", nil); err != nil {
		return err
	}

	// pk holds the last values of the fetched primary keys
	var pk []driver.Value
	// The first query has no WHERE clause.
	q := fmt.Sprintf(bs, "")
	for {
		rows, err := conn.Query(q, pk)
		if err != nil {
			return err
		}
		// The leading len(index) result columns are the key columns added by
		// the SELECT; the remaining ones are the table's own columns.
		cols := rows.Columns()
		pkcols := cols[:len(index)]
		cols = cols[len(index):]
		inserts := make([][]string, 0, limit)
		i := 0
		for i < limit {
			vals := make([]driver.Value, len(cols)+len(pkcols))
			if err := rows.Next(vals); err == io.EOF {
				break
			} else if err != nil {
				return err
			}
			// Once a first row has been seen, switch to the query that
			// carries the WHERE clause for all following pages.
			if pk == nil {
				q = fmt.Sprintf(bs, wbuf.String())
			}
			pk = vals[:len(index)]
			vals = vals[len(index):]
			ivals := make([]string, len(vals))
			// Values need to be correctly encoded for INSERT statements in a text file.
			for si, sv := range vals {
				switch t := sv.(type) {
				case nil:
					ivals[si] = "NULL"
				case bool:
					ivals[si] = parser.MakeDBool(parser.DBool(t)).String()
				case int64:
					ivals[si] = parser.NewDInt(parser.DInt(t)).String()
				case float64:
					ivals[si] = parser.NewDFloat(parser.DFloat(t)).String()
				case string:
					ivals[si] = parser.NewDString(t).String()
				case []byte:
					// Byte slices are encoded according to the column's type.
					switch ct := coltypes[cols[si]]; ct {
					case "INTERVAL":
						ivals[si] = fmt.Sprintf("'%s'", t)
					case "BYTES":
						ivals[si] = parser.NewDBytes(parser.DBytes(t)).String()
					default:
						// STRING and DECIMAL types can have optional length
						// suffixes, so only examine the prefix of the type.
						if strings.HasPrefix(coltypes[cols[si]], "STRING") {
							ivals[si] = parser.NewDString(string(t)).String()
						} else if strings.HasPrefix(coltypes[cols[si]], "DECIMAL") {
							ivals[si] = string(t)
						} else {
							panic(errors.Errorf("unknown []byte type: %s, %v: %s", t, cols[si], coltypes[cols[si]]))
						}
					}
				case time.Time:
					// Time values are converted to the datum matching the
					// column's type and written as a quoted string.
					var d parser.Datum
					ct := coltypes[cols[si]]
					switch ct {
					case "DATE":
						d = parser.NewDDateFromTime(t, time.UTC)
					case "TIMESTAMP":
						d = parser.MakeDTimestamp(t, time.Nanosecond)
					case "TIMESTAMP WITH TIME ZONE":
						d = parser.MakeDTimestampTZ(t, time.Nanosecond)
					default:
						panic(errors.Errorf("unknown timestamp type: %s, %v: %s", t, cols[si], coltypes[cols[si]]))
					}
					ivals[si] = fmt.Sprintf("'%s'", d)
				default:
					panic(errors.Errorf("unknown field type: %T (%s)", t, cols[si]))
				}
			}
			inserts = append(inserts, ivals)
			i++
		}
		// pk now holds the key of the last row read; it is passed as the
		// query arguments for the next page.
		for si, sv := range pk {
			b, ok := sv.([]byte)
			if ok && strings.HasPrefix(coltypes[pkcols[si]], "STRING") {
				// Primary key strings need to be converted to a go string, but not SQL
				// encoded since they aren't being written to a text file.
				pk[si] = string(b)
			}
		}
		if err := rows.Close(); err != nil {
			return err
		}
		// Nothing was read: the table has been fully dumped.
		if i == 0 {
			break
		}
		// Write the page as a single multi-row INSERT statement.
		// NOTE(review): errors from these writes to w are not checked.
		fmt.Fprintf(w, "\nINSERT INTO %s VALUES", tablename)
		for idx, values := range inserts {
			if idx > 0 {
				fmt.Fprint(w, ",")
			}
			fmt.Fprint(w, "\n\t(")
			for vi, v := range values {
				if vi > 0 {
					fmt.Fprint(w, ", ")
				}
				fmt.Fprint(w, v)
			}
			fmt.Fprint(w, ")")
		}
		fmt.Fprintln(w, ";")
		// A short page means there are no more rows to fetch.
		if i < limit {
			break
		}
	}
	return nil
}
예제 #25
0
// getDataSource builds a planDataSource from a single data source clause
// (TableExpr) in a SelectClause.
//
// Table names resolve to a virtual table when one exists and to a table scan
// or view plan otherwise; subqueries, joins, parenthesized expressions and
// aliased expressions are handled recursively. Any other kind of TableExpr is
// rejected with an error.
func (p *planner) getDataSource(
	src parser.TableExpr, hints *parser.IndexHints, scanVisibility scanVisibility,
) (planDataSource, error) {
	switch t := src.(type) {
	case *parser.NormalizableTableName:
		// Usual case: a table.
		qualified, err := p.QualifyWithDatabase(t)
		if err != nil {
			return planDataSource{}, err
		}

		// The name may designate a virtual table.
		virtualDS, isVirtual, err := p.getVirtualDataSource(qualified)
		if err != nil {
			return planDataSource{}, err
		}
		if !isVirtual {
			return p.getTableScanOrViewPlan(qualified, hints, scanVisibility)
		}
		return virtualDS, nil

	case *parser.Subquery:
		return p.getSubqueryPlan(t.Select, nil)

	case *parser.JoinTableExpr:
		// Joins: two sources, planned left first.
		leftDS, err := p.getDataSource(t.Left, nil, scanVisibility)
		if err != nil {
			return leftDS, err
		}
		rightDS, err := p.getDataSource(t.Right, nil, scanVisibility)
		if err != nil {
			return rightDS, err
		}
		return p.makeJoin(t.Join, leftDS, rightDS, t.Cond)

	case *parser.ParenTableExpr:
		return p.getDataSource(t.Expr, hints, scanVisibility)

	case *parser.AliasedTableExpr:
		// Alias clause: source AS alias(cols...)
		aliased, err := p.getDataSource(t.Expr, t.Hints, scanVisibility)
		if err != nil {
			return aliased, err
		}

		var tableAlias parser.TableName
		if t.As.Alias != "" {
			// An explicit alias replaces the source's aliases; it covers all
			// of the source's columns.
			tableAlias.TableName = parser.Name(t.As.Alias.Normalize())
			aliased.info.sourceAliases = sourceAliases{
				tableAlias: fillColumnRange(0, len(aliased.info.sourceColumns)-1),
			}
		}

		colAliases := t.As.Cols
		if len(colAliases) == 0 {
			return aliased, nil
		}

		// Work on a copy of the slice since its contents are about to change.
		aliased.info.sourceColumns = append(ResultColumns(nil), aliased.info.sourceColumns...)

		// The column aliases can only refer to explicit columns, so each
		// alias is given to the next column that is not hidden.
		colIdx := 0
		for aliasIdx, colAlias := range colAliases {
			for colIdx < len(aliased.info.sourceColumns) && aliased.info.sourceColumns[colIdx].hidden {
				colIdx++
			}
			if colIdx >= len(aliased.info.sourceColumns) {
				// More aliases than visible columns; aliasIdx is the number
				// of columns that could be named.
				var srcName string
				if tableAlias.DatabaseName != "" {
					srcName = tableAlias.String()
				} else {
					srcName = tableAlias.TableName.String()
				}
				return planDataSource{}, errors.Errorf(
					"source %q has %d columns available but %d columns specified",
					srcName, aliasIdx, len(colAliases))
			}
			aliased.info.sourceColumns[colIdx].Name = string(colAlias)
			colIdx++
		}
		return aliased, nil

	default:
		return planDataSource{}, errors.Errorf("unsupported FROM type %T", src)
	}
}
예제 #26
0
// TableStats is an endpoint that returns replication and storage statistics
// for the specified table: the number of ranges, the number of nodes and
// replicas holding its data, and the span statistics aggregated across those
// nodes.
//
// A node whose SpanStats query returns an error is reported in the response's
// MissingNodes list instead of failing the whole request. An error is returned
// if the table span or its range descriptors cannot be read, if a per-node
// task cannot be started, or if ctx is done before every node has answered.
func (s *adminServer) TableStats(
	ctx context.Context, req *serverpb.TableStatsRequest,
) (*serverpb.TableStatsResponse, error) {
	escDBName := parser.Name(req.Database).String()
	if err := s.assertNotVirtualSchema(escDBName); err != nil {
		return nil, err
	}

	// Get table span.
	var tableSpan roachpb.Span
	iexecutor := sql.InternalExecutor{LeaseManager: s.server.leaseMgr}
	if err := s.server.db.Txn(ctx, func(txn *client.Txn) error {
		var err error
		tableSpan, err = iexecutor.GetTableSpan(s.getUser(req), txn, req.Database, req.Table)
		return err
	}); err != nil {
		return nil, s.serverError(err)
	}

	startKey, err := keys.Addr(tableSpan.Key)
	if err != nil {
		return nil, s.serverError(err)
	}
	endKey, err := keys.Addr(tableSpan.EndKey)
	if err != nil {
		return nil, s.serverError(err)
	}

	// Get current range descriptors for table. This is done by scanning over
	// meta2 keys for the range.
	rangeDescKVs, err := s.server.db.Scan(ctx, keys.RangeMetaKey(startKey), keys.RangeMetaKey(endKey), 0)
	if err != nil {
		return nil, s.serverError(err)
	}

	// Extract a list of node IDs from the response.
	nodeIDs := make(map[roachpb.NodeID]struct{})
	for _, kv := range rangeDescKVs {
		var rng roachpb.RangeDescriptor
		if err := kv.Value.GetProto(&rng); err != nil {
			return nil, s.serverError(err)
		}
		for _, repl := range rng.Replicas {
			nodeIDs[repl.NodeID] = struct{}{}
		}
	}

	// Construct TableStatsResponse by sending an RPC to every node involved.
	tableStatResponse := serverpb.TableStatsResponse{
		NodeCount: int64(len(nodeIDs)),
		// TODO(mrtracy): The "RangeCount" returned by TableStats is more
		// accurate than the "RangeCount" returned by TableDetails, because this
		// method always consistently queries the meta2 key range for the table;
		// in contrast, TableDetails uses a method on the DistSender, which
		// queries using a range metadata cache and thus may return stale data
		// for tables that are rapidly splitting. However, one potential
		// *advantage* of using the DistSender is that it will populate the
		// DistSender's range metadata cache in the case where meta2 information
		// for this table is not already present; the query used by TableStats
		// does not populate the DistSender cache. We should consider plumbing
		// TableStats' meta2 query through the DistSender so that it will share
		// the advantage of populating the cache (without the disadvantage of
		// potentially returning stale data).
		// See Github #5435 for some discussion.
		RangeCount: int64(len(rangeDescKVs)),
	}
	type nodeResponse struct {
		nodeID roachpb.NodeID
		resp   *serverpb.SpanStatsResponse
		err    error
	}

	// Send a SpanStats query to each node. Set a timeout on the context for
	// these queries.
	//
	// The channel is buffered with one slot per node so that every task can
	// always deliver its result without blocking. With an unbuffered channel
	// the tasks had to give up on nodeCtx.Done(); once the timeout fired a
	// task could drop its response, leaving the collection loop below waiting
	// for a response that would never arrive.
	responses := make(chan nodeResponse, len(nodeIDs))
	nodeCtx, cancel := context.WithTimeout(ctx, base.NetworkTimeout)
	defer cancel()
	for nodeID := range nodeIDs {
		nodeID := nodeID
		if err := s.server.stopper.RunAsyncTask(nodeCtx, func(ctx context.Context) {
			var spanResponse *serverpb.SpanStatsResponse
			statusClient, err := s.server.status.dialNode(nodeID)
			if err == nil {
				spanReq := serverpb.SpanStatsRequest{
					StartKey: startKey,
					EndKey:   endKey,
					NodeID:   nodeID.String(),
				}
				spanResponse, err = statusClient.SpanStats(ctx, &spanReq)
			}

			// Never blocks: the buffer holds one response per node, so the
			// task finishes even if the handler has already returned.
			responses <- nodeResponse{
				nodeID: nodeID,
				resp:   spanResponse,
				err:    err,
			}
		}); err != nil {
			return nil, err
		}
	}
	for remainingResponses := len(nodeIDs); remainingResponses > 0; remainingResponses-- {
		select {
		case resp := <-responses:
			// For nodes which returned an error, note that the node's data
			// is missing. For successful calls, aggregate statistics.
			if resp.err != nil {
				tableStatResponse.MissingNodes = append(
					tableStatResponse.MissingNodes,
					serverpb.TableStatsResponse_MissingNode{
						NodeID:       resp.nodeID.String(),
						ErrorMessage: resp.err.Error(),
					},
				)
			} else {
				tableStatResponse.Stats.Add(resp.resp.TotalStats)
				tableStatResponse.ReplicaCount += int64(resp.resp.RangeCount)
			}
		case <-ctx.Done():
			// The caller gave up; stop waiting for the remaining nodes.
			return nil, ctx.Err()
		}
	}

	return &tableStatResponse, nil
}
예제 #27
0
// TableDetails is an endpoint that returns columns, indices, and other
// relevant details for the specified table.
//
// The response is assembled from four SQL statements run in a single batch
// (SHOW COLUMNS, SHOW INDEX, SHOW GRANTS and SHOW CREATE TABLE), plus the
// number of ranges covering the table's key span, the table's descriptor ID
// and the zone configuration that applies to it.
//
// If any statement reports a not-found error (as detected by
// firstNotFoundError) a gRPC NotFound error is returned.
func (s *adminServer) TableDetails(
	ctx context.Context, req *serverpb.TableDetailsRequest,
) (*serverpb.TableDetailsResponse, error) {
	args := sql.SessionArgs{User: s.getUser(req)}
	session := s.NewSessionForRPC(ctx, args)
	defer session.Finish(s.server.sqlExecutor)

	escDBName := parser.Name(req.Database).String()
	if err := s.assertNotVirtualSchema(escDBName); err != nil {
		return nil, err
	}

	// TODO(cdo): Use real placeholders for the table and database names when we've extended our SQL
	// grammar to allow that.
	escTableName := parser.Name(req.Table).String()
	escQualTable := fmt.Sprintf("%s.%s", escDBName, escTableName)
	// The four statements are executed as one batch; their results are read
	// back below by position in r.ResultList (0: columns, 1: indexes,
	// 2: grants, 3: create statement).
	query := fmt.Sprintf("SHOW COLUMNS FROM %s; SHOW INDEX FROM %s; SHOW GRANTS ON TABLE %s; SHOW CREATE TABLE %s;",
		escQualTable, escQualTable, escQualTable, escQualTable)
	r := s.server.sqlExecutor.ExecuteStatements(session, query, nil)
	defer r.Close()
	if err := s.firstNotFoundError(r.ResultList); err != nil {
		return nil, grpc.Errorf(codes.NotFound, "%s", err)
	}
	// Expect exactly one result per statement in the batch.
	if err := s.checkQueryResults(r.ResultList, 4); err != nil {
		return nil, err
	}

	var resp serverpb.TableDetailsResponse

	// Marshal SHOW COLUMNS result.
	//
	// TODO(cdo): protobuf v3's default behavior for fields with zero values (e.g. empty strings)
	// is to suppress them. So, if protobuf field "foo" is an empty string, "foo" won't show
	// up in the marshalled JSON. I feel that this is counterintuitive, and this should be fixed
	// for our API.
	{
		const (
			fieldCol   = "Field" // column name
			typeCol    = "Type"
			nullCol    = "Null"
			defaultCol = "Default"
		)
		scanner := makeResultScanner(r.ResultList[0].Columns)
		for i, nRows := 0, r.ResultList[0].Rows.Len(); i < nRows; i++ {
			row := r.ResultList[0].Rows.At(i)
			var col serverpb.TableDetailsResponse_Column
			if err := scanner.Scan(row, fieldCol, &col.Name); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, typeCol, &col.Type); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, nullCol, &col.Nullable); err != nil {
				return nil, err
			}
			// A NULL default is left as the zero value of col.DefaultValue
			// rather than being scanned.
			isDefaultNull, err := scanner.IsNull(row, defaultCol)
			if err != nil {
				return nil, err
			}
			if !isDefaultNull {
				if err := scanner.Scan(row, defaultCol, &col.DefaultValue); err != nil {
					return nil, err
				}
			}
			resp.Columns = append(resp.Columns, col)
		}
	}

	// Marshal SHOW INDEX result.
	{
		const (
			nameCol      = "Name"
			uniqueCol    = "Unique"
			seqCol       = "Seq"
			columnCol    = "Column"
			directionCol = "Direction"
			storingCol   = "Storing"
			implicitCol  = "Implicit"
		)
		scanner := makeResultScanner(r.ResultList[1].Columns)
		for i, nRows := 0, r.ResultList[1].Rows.Len(); i < nRows; i++ {
			row := r.ResultList[1].Rows.At(i)
			// Marshal one SHOW INDEX row into an Index entry.
			var index serverpb.TableDetailsResponse_Index
			if err := scanner.Scan(row, nameCol, &index.Name); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, uniqueCol, &index.Unique); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, seqCol, &index.Seq); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, columnCol, &index.Column); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, directionCol, &index.Direction); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, storingCol, &index.Storing); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, implicitCol, &index.Implicit); err != nil {
				return nil, err
			}
			resp.Indexes = append(resp.Indexes, index)
		}
	}

	// Marshal SHOW GRANTS result.
	{
		const (
			userCol       = "User"
			privilegesCol = "Privileges"
		)
		scanner := makeResultScanner(r.ResultList[2].Columns)
		for i, nRows := 0, r.ResultList[2].Rows.Len(); i < nRows; i++ {
			row := r.ResultList[2].Rows.At(i)
			// Marshal grant, splitting comma-separated privileges into a proper slice.
			var grant serverpb.TableDetailsResponse_Grant
			var privileges string
			if err := scanner.Scan(row, userCol, &grant.User); err != nil {
				return nil, err
			}
			if err := scanner.Scan(row, privilegesCol, &privileges); err != nil {
				return nil, err
			}
			grant.Privileges = strings.Split(privileges, ",")
			resp.Grants = append(resp.Grants, grant)
		}
	}

	// Marshal SHOW CREATE TABLE result.
	{
		const createTableCol = "CreateTable"
		showResult := r.ResultList[3]
		// Exactly one row is expected; anything else is reported as a server
		// error.
		if showResult.Rows.Len() != 1 {
			return nil, s.serverErrorf("CreateTable response not available.")
		}

		scanner := makeResultScanner(showResult.Columns)
		var createStmt string
		if err := scanner.Scan(showResult.Rows.At(0), createTableCol, &createStmt); err != nil {
			return nil, err
		}

		resp.CreateTableStatement = createStmt
	}

	// Get the number of ranges in the table. We get the key span for the table
	// data. Then, we count the number of ranges that make up that key span.
	{
		iexecutor := sql.InternalExecutor{LeaseManager: s.server.leaseMgr}
		var tableSpan roachpb.Span
		if err := s.server.db.Txn(ctx, func(txn *client.Txn) error {
			var err error
			tableSpan, err = iexecutor.GetTableSpan(
				s.getUser(req), txn, req.Database, req.Table,
			)
			return err
		}); err != nil {
			return nil, s.serverError(err)
		}
		// Convert both ends of the span with keys.Addr to build the RSpan
		// that CountRanges takes.
		tableRSpan := roachpb.RSpan{}
		var err error
		tableRSpan.Key, err = keys.Addr(tableSpan.Key)
		if err != nil {
			return nil, s.serverError(err)
		}
		tableRSpan.EndKey, err = keys.Addr(tableSpan.EndKey)
		if err != nil {
			return nil, s.serverError(err)
		}
		rangeCount, err := s.server.distSender.CountRanges(ctx, tableRSpan)
		if err != nil {
			return nil, s.serverError(err)
		}
		resp.RangeCount = rangeCount
	}

	// Query the descriptor ID and zone configuration for this table.
	{
		// NOTE(review): path appears to hold one ID per level with the
		// database at index 1 and the table at index 2 (index 0 presumably
		// being the root) -- confirm against queryDescriptorIDPath.
		path, err := s.queryDescriptorIDPath(session, []string{req.Database, req.Table})
		if err != nil {
			return nil, s.serverError(err)
		}
		resp.DescriptorID = int64(path[2])

		id, zone, zoneExists, err := s.queryZonePath(session, path)
		if err != nil {
			return nil, s.serverError(err)
		}

		// Fall back to the default zone configuration when none was found.
		if !zoneExists {
			zone = config.DefaultZoneConfig()
		}
		resp.ZoneConfig = zone

		// Report which level the zone configuration came from, based on the
		// ID returned by queryZonePath: the database entry, the table entry,
		// or (for any other ID) the cluster.
		switch id {
		case path[1]:
			resp.ZoneConfigLevel = serverpb.ZoneConfigurationLevel_DATABASE
		case path[2]:
			resp.ZoneConfigLevel = serverpb.ZoneConfigurationLevel_TABLE
		default:
			resp.ZoneConfigLevel = serverpb.ZoneConfigurationLevel_CLUSTER
		}
	}

	return &resp, nil
}
예제 #28
0
// DatabaseDetails is an endpoint that reports, for the specified database,
// its grants, the names of its tables, its descriptor ID, and the zone
// configuration that applies to it.
//
// If either SHOW statement reports a not-found error (as detected by
// firstNotFoundError) a gRPC NotFound error is returned.
func (s *adminServer) DatabaseDetails(
	ctx context.Context, req *serverpb.DatabaseDetailsRequest,
) (*serverpb.DatabaseDetailsResponse, error) {
	session := s.NewSessionForRPC(ctx, sql.SessionArgs{User: s.getUser(req)})
	defer session.Finish(s.server.sqlExecutor)

	escDBName := parser.Name(req.Database).String()
	if err := s.assertNotVirtualSchema(escDBName); err != nil {
		return nil, err
	}

	// SHOW statements do not accept placeholders, so the escaped database
	// name is spliced directly into the query text.
	//
	// TODO(cdo): Use placeholders when they're supported by SHOW.
	query := fmt.Sprintf("SHOW GRANTS ON DATABASE %s; SHOW TABLES FROM %s;", escDBName, escDBName)
	results := s.server.sqlExecutor.ExecuteStatements(session, query, nil)
	defer results.Close()
	if err := s.firstNotFoundError(results.ResultList); err != nil {
		return nil, grpc.Errorf(codes.NotFound, "%s", err)
	}
	if err := s.checkQueryResults(results.ResultList, 2); err != nil {
		return nil, s.serverError(err)
	}

	var details serverpb.DatabaseDetailsResponse

	// Grants: one entry per SHOW GRANTS row, with the comma-separated
	// privilege list split into a slice.
	const (
		userCol       = "User"
		privilegesCol = "Privileges"
	)
	grantsResult := results.ResultList[0]
	grantScanner := makeResultScanner(grantsResult.Columns)
	numGrants := grantsResult.Rows.Len()
	for idx := 0; idx < numGrants; idx++ {
		row := grantsResult.Rows.At(idx)
		var (
			grant    serverpb.DatabaseDetailsResponse_Grant
			privList string
		)
		if err := grantScanner.Scan(row, userCol, &grant.User); err != nil {
			return nil, err
		}
		if err := grantScanner.Scan(row, privilegesCol, &privList); err != nil {
			return nil, err
		}
		grant.Privileges = strings.Split(privList, ",")
		details.Grants = append(details.Grants, grant)
	}

	// Table names: SHOW TABLES must yield exactly one column.
	const tableCol = "Table"
	tablesResult := results.ResultList[1]
	tableScanner := makeResultScanner(tablesResult.Columns)
	if numCols := len(tablesResult.Columns); numCols != 1 {
		return nil, s.serverErrorf("show tables columns mismatch: %d != expected %d", numCols, 1)
	}
	numTables := tablesResult.Rows.Len()
	for idx := 0; idx < numTables; idx++ {
		var name string
		if err := tableScanner.Scan(tablesResult.Rows.At(idx), tableCol, &name); err != nil {
			return nil, err
		}
		details.TableNames = append(details.TableNames, name)
	}

	// Descriptor ID and zone configuration for this database.
	path, err := s.queryDescriptorIDPath(session, []string{req.Database})
	if err != nil {
		return nil, s.serverError(err)
	}
	details.DescriptorID = int64(path[1])

	id, zone, zoneExists, err := s.queryZonePath(session, path)
	if err != nil {
		return nil, s.serverError(err)
	}
	// Fall back to the default zone configuration when none was found.
	if !zoneExists {
		zone = config.DefaultZoneConfig()
	}
	details.ZoneConfig = zone

	// The zone is database-level only when its ID is the database's own
	// entry in the path; any other ID is reported as cluster-level.
	if id == path[1] {
		details.ZoneConfigLevel = serverpb.ZoneConfigurationLevel_DATABASE
	} else {
		details.ZoneConfigLevel = serverpb.ZoneConfigurationLevel_CLUSTER
	}

	return &details, nil
}
예제 #29
0
// validateForeignKey checks the rows of srcTable against the foreign key
// constraint attached to srcIdx.
//
// It builds and runs a query that LEFT OUTER JOINs the source index with the
// referenced index on their shared leading columns (treating two NULLs as
// equal) and selects one row where a source column is non-NULL but the
// matching referenced column is NULL. If such a row is returned, a
// "foreign key violation" error naming that row's column values is returned;
// otherwise the result is nil. Lookup and query errors are returned as-is.
func (p *planner) validateForeignKey(
	ctx context.Context, srcTable *sqlbase.TableDescriptor, srcIdx *sqlbase.IndexDescriptor,
) error {
	targetTable, err := sqlbase.GetTableDescFromID(p.txn, srcIdx.ForeignKey.Table)
	if err != nil {
		return err
	}
	targetIdx, err := targetTable.FindIndexByID(srcIdx.ForeignKey.Index)
	if err != nil {
		return err
	}

	srcName, err := p.getQualifiedTableName(srcTable)
	if err != nil {
		return err
	}
	targetName, err := p.getQualifiedTableName(targetTable)
	if err != nil {
		return err
	}

	// Only the columns present in both indexes take part in the comparison.
	numCols := len(srcIdx.ColumnNames)
	if n := len(targetIdx.ColumnNames); n < numCols {
		numCols = n
	}

	var (
		srcCols    = make([]string, numCols)
		targetCols = make([]string, numCols)
		joinConds  = make([]string, numCols)
		violations = make([]string, numCols)
	)
	for i := range srcCols {
		// Identifiers are escaped by formatting them as parser.Name values.
		sCol := fmt.Sprintf("s.%s", parser.Name(srcIdx.ColumnNames[i]).String())
		tCol := fmt.Sprintf("t.%s", parser.Name(targetIdx.ColumnNames[i]).String())
		srcCols[i], targetCols[i] = sCol, tCol
		joinConds[i] = fmt.Sprintf("(%s = %s OR (%s IS NULL AND %s IS NULL))",
			sCol, tCol, sCol, tCol)
		violations[i] = fmt.Sprintf("(%s IS NOT NULL AND %s IS NULL)", sCol, tCol)
	}

	escSrcIdxName := parser.Name(srcIdx.Name).String()
	escTargetIdxName := parser.Name(targetIdx.Name).String()
	query := fmt.Sprintf(
		`SELECT %s FROM %s@%s AS s LEFT OUTER JOIN %s@%s AS t ON %s WHERE %s LIMIT 1`,
		strings.Join(srcCols, ", "),
		srcName, escSrcIdxName, targetName, escTargetIdxName,
		strings.Join(joinConds, " AND "),
		strings.Join(violations, " OR "),
	)

	log.Infof(ctx, "Validating FK %q (%q [%v] -> %q [%v]) with query %q",
		srcIdx.ForeignKey.Name,
		srcTable.Name, srcCols, targetTable.Name, targetCols,
		query,
	)

	rows, err := p.queryRows(query)
	if err != nil {
		return err
	}
	if len(rows) == 0 {
		return nil
	}

	// Describe the offending row as "col=value, col=value, ...".
	var pairs bytes.Buffer
	for i, val := range rows[0] {
		if i > 0 {
			pairs.WriteString(", ")
		}
		fmt.Fprintf(&pairs, "%s=%v", srcIdx.ColumnNames[i], val)
	}
	return errors.Errorf("foreign key violation: %q row %s has no match in %q",
		srcTable.Name, pairs.String(), targetTable.Name)
}
예제 #30
0
// testAdminAPITableDetailsInner starts a test server, creates database dbName
// containing table tblName (with grants for the users "readonly" and "app"
// and a descending secondary index), fetches the table through the admin API
// endpoint databases/<dbName>/tables/<tblName>, and verifies the returned
// columns, grants, indexes, range count, CREATE TABLE statement and
// descriptor ID.
//
// dbName and tblName are escaped before being used in SQL, but are placed in
// the request URL unmodified.
func testAdminAPITableDetailsInner(t *testing.T, dbName, tblName string) {
	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
	defer s.Stopper().Stop()
	ts := s.(*TestServer)

	escDBName := parser.Name(dbName).String()
	escTblName := parser.Name(tblName).String()

	ac := log.AmbientContext{Tracer: tracing.NewTracer()}
	ctx, span := ac.AnnotateCtxWithSpan(context.Background(), "test")
	defer span.Finish()

	session := sql.NewSession(
		ctx, sql.SessionArgs{User: security.RootUser}, ts.sqlExecutor, nil, &sql.MemoryMetrics{})
	session.StartUnlimitedMonitor()
	defer session.Finish(ts.sqlExecutor)
	setupQueries := []string{
		fmt.Sprintf("CREATE DATABASE %s", escDBName),
		fmt.Sprintf(`CREATE TABLE %s.%s (
	nulls_allowed INT,
	nulls_not_allowed INT NOT NULL DEFAULT 1000,
	default2 INT DEFAULT 2,
	string_default STRING DEFAULT 'default_string'
)`, escDBName, escTblName),
		fmt.Sprintf("GRANT SELECT ON %s.%s TO readonly", escDBName, escTblName),
		fmt.Sprintf("GRANT SELECT,UPDATE,DELETE ON %s.%s TO app", escDBName, escTblName),
		fmt.Sprintf("CREATE INDEX descIdx ON %s.%s (default2 DESC)", escDBName, escTblName),
	}

	for _, q := range setupQueries {
		// Run each statement in its own function so that its result is
		// closed as soon as it has been checked, instead of piling up
		// deferred Close calls until the whole test returns.
		func() {
			res := ts.sqlExecutor.ExecuteStatements(session, q, nil)
			defer res.Close()
			if res.ResultList[0].Err != nil {
				t.Fatalf("error executing '%s': %s", q, res.ResultList[0].Err)
			}
		}()
	}

	// Perform API call.
	var resp serverpb.TableDetailsResponse
	url := fmt.Sprintf("databases/%s/tables/%s", dbName, tblName)
	if err := getAdminJSONProto(s, url, &resp); err != nil {
		t.Fatal(err)
	}

	// Verify columns.
	expColumns := []serverpb.TableDetailsResponse_Column{
		{Name: "nulls_allowed", Type: "INT", Nullable: true, DefaultValue: ""},
		{Name: "nulls_not_allowed", Type: "INT", Nullable: false, DefaultValue: "1000"},
		{Name: "default2", Type: "INT", Nullable: true, DefaultValue: "2"},
		{Name: "string_default", Type: "STRING", Nullable: true, DefaultValue: "'default_string'"},
		{Name: "rowid", Type: "INT", Nullable: false, DefaultValue: "unique_rowid()"},
	}
	testutils.SortStructs(expColumns, "Name")
	testutils.SortStructs(resp.Columns, "Name")
	if a, e := len(resp.Columns), len(expColumns); a != e {
		t.Fatalf("# of result columns %d != expected %d (got: %#v)", a, e, resp.Columns)
	}
	for i, a := range resp.Columns {
		e := expColumns[i]
		if a.String() != e.String() {
			t.Fatalf("mismatch at column %d: actual %#v != %#v", i, a, e)
		}
	}

	// Verify grants. The expected users and privileges mirror the GRANT
	// statements issued during setup: "app" was granted SELECT, UPDATE and
	// DELETE, and "readonly" was granted SELECT.
	expGrants := []serverpb.TableDetailsResponse_Grant{
		{User: security.RootUser, Privileges: []string{"ALL"}},
		{User: "app", Privileges: []string{"DELETE", "SELECT", "UPDATE"}},
		{User: "readonly", Privileges: []string{"SELECT"}},
	}
	testutils.SortStructs(expGrants, "User")
	testutils.SortStructs(resp.Grants, "User")
	if a, e := len(resp.Grants), len(expGrants); a != e {
		t.Fatalf("# of grant columns %d != expected %d (got: %#v)", a, e, resp.Grants)
	}
	for i, a := range resp.Grants {
		e := expGrants[i]
		sort.Strings(a.Privileges)
		sort.Strings(e.Privileges)
		if a.String() != e.String() {
			t.Fatalf("mismatch at index %d: actual %#v != %#v", i, a, e)
		}
	}

	// Verify indexes.
	expIndexes := []serverpb.TableDetailsResponse_Index{
		{Name: "primary", Column: "rowid", Direction: "ASC", Unique: true, Seq: 1},
		{Name: "descIdx", Column: "default2", Direction: "DESC", Unique: false, Seq: 1},
	}
	testutils.SortStructs(expIndexes, "Column")
	testutils.SortStructs(resp.Indexes, "Column")
	// Compare lengths first: without this, a response with too few indexes
	// would pass silently and one with too many would panic on expIndexes[i].
	if a, e := len(resp.Indexes), len(expIndexes); a != e {
		t.Fatalf("# of indexes %d != expected %d (got: %#v)", a, e, resp.Indexes)
	}
	for i, a := range resp.Indexes {
		e := expIndexes[i]
		if a.String() != e.String() {
			t.Fatalf("mismatch at index %d: actual %#v != %#v", i, a, e)
		}
	}

	// Verify range count.
	if a, e := resp.RangeCount, int64(1); a != e {
		t.Fatalf("# of ranges %d != expected %d", a, e)
	}

	// Verify Create Table Statement: the API response must match what
	// SHOW CREATE TABLE returns when run directly.
	{

		const createTableCol = "CreateTable"
		showCreateTableQuery := fmt.Sprintf("SHOW CREATE TABLE %s.%s", escDBName, escTblName)

		resSet := ts.sqlExecutor.ExecuteStatements(session, showCreateTableQuery, nil)
		defer resSet.Close()
		res := resSet.ResultList[0]
		if res.Err != nil {
			t.Fatalf("error executing '%s': %s", showCreateTableQuery, res.Err)
		}

		scanner := makeResultScanner(res.Columns)
		var createStmt string
		if err := scanner.Scan(res.Rows.At(0), createTableCol, &createStmt); err != nil {
			t.Fatal(err)
		}

		if a, e := resp.CreateTableStatement, createStmt; a != e {
			t.Fatalf("mismatched create table statement; expected %s, got %s", e, a)
		}
	}

	// Verify Descriptor ID.
	path, err := ts.admin.queryDescriptorIDPath(session, []string{dbName, tblName})
	if err != nil {
		t.Fatal(err)
	}
	if a, e := resp.DescriptorID, int64(path[2]); a != e {
		t.Fatalf("table had descriptorID %d, expected %d", a, e)
	}
}