Example #1
0
// detectSelectAgg reports whether sel must be planned as an aggregate query.
//
// That is the case when it has a GROUP BY clause, or when an aggregate
// function appears in the select fields, in the HAVING condition, or in an
// ORDER BY item. An aggregate used without GROUP BY groups over all rows,
// so an aggregate appearing only in HAVING or ORDER BY
// (e.g. "SELECT 1 FROM t HAVING COUNT(*) > 0") still makes the query an
// aggregate query.
func (b *planBuilder) detectSelectAgg(sel *ast.SelectStmt) bool {
	if sel.GroupBy != nil {
		return true
	}
	for _, f := range sel.GetResultFields() {
		if ast.HasAggFlag(f.Expr) {
			return true
		}
	}
	if sel.Having != nil && ast.HasAggFlag(sel.Having.Expr) {
		return true
	}
	if sel.OrderBy != nil {
		for _, item := range sel.OrderBy.Items {
			if ast.HasAggFlag(item.Expr) {
				return true
			}
		}
	}
	return false
}
Example #2
0
// Enter implements ast.Visitor interface.
//
// Before a node's children are visited it:
//   - calls nr.pushContext() for statement nodes (ADMIN, DELETE, INSERT,
//     SELECT, UNION, UPDATE);
//   - records joins with nr.pushJoin;
//   - sets a clause flag (inDeleteTableList, inFieldList, inGroupBy,
//     inHaving, inOnCondition, inOrderBy, inTableRefs) on the current context
//     for clause nodes;
//   - rejects GROUP BY items containing an aggregate function by setting
//     nr.Err to ErrInvalidGroupFuncUse and skipping the item's children.
//
// The visited node itself is always returned as outNode.
func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
	switch node := inNode.(type) {
	// Statements that only open a new context.
	case *ast.AdminStmt, *ast.DeleteStmt, *ast.InsertStmt,
		*ast.SelectStmt, *ast.UnionStmt, *ast.UpdateStmt:
		nr.pushContext()
	case *ast.Join:
		nr.pushJoin(node)
	case *ast.AggregateFuncExpr:
		// Note that an aggregate was seen while the HAVING flag is set.
		if ctx := nr.currentContext(); ctx.inHaving {
			ctx.inHavingAgg = true
		}
	case *ast.ByItem:
		if _, isColumn := node.Expr.(*ast.ColumnNameExpr); !isColumn {
			// A by-item that is not a bare column name is resolved with a
			// different rule than a plain ORDER BY column.
			nr.currentContext().inByItemExpression = true
		}
		// A GROUP BY item must not contain an aggregate function.
		if nr.currentContext().inGroupBy && ast.HasAggFlag(node.Expr) {
			nr.Err = ErrInvalidGroupFuncUse
			return inNode, true
		}
	// Clause nodes: mark which clause is being visited.
	case *ast.DeleteTableList:
		nr.currentContext().inDeleteTableList = true
	case *ast.FieldList:
		nr.currentContext().inFieldList = true
	case *ast.GroupByClause:
		nr.currentContext().inGroupBy = true
	case *ast.HavingClause:
		nr.currentContext().inHaving = true
	case *ast.OnCondition:
		nr.currentContext().inOnCondition = true
	case *ast.OrderByClause:
		nr.currentContext().inOrderBy = true
	case *ast.TableRefsClause:
		nr.currentContext().inTableRefs = true
	}
	return inNode, false
}
Example #3
0
// detectSelectAgg reports whether sel is an aggregate query.
//
// It returns true when sel has a GROUP BY clause, or when an aggregate
// function appears in the select fields, the HAVING condition, or any
// ORDER BY item (checked in that order).
func (b *planBuilder) detectSelectAgg(sel *ast.SelectStmt) bool {
	if sel.GroupBy != nil {
		return true
	}
	for _, field := range sel.GetResultFields() {
		if ast.HasAggFlag(field.Expr) {
			return true
		}
	}
	if having := sel.Having; having != nil && ast.HasAggFlag(having.Expr) {
		return true
	}
	if order := sel.OrderBy; order != nil {
		for _, by := range order.Items {
			if ast.HasAggFlag(by.Expr) {
				return true
			}
		}
	}
	return false
}
Example #4
0
// handlePosition resolves a 1-based positional reference (such as the "1" in
// "GROUP BY 1") by pointing pos.Refer at the matching entry of the current
// context's field list.
//
// An out-of-range position sets nr.Err. Inside GROUP BY, a position that
// refers to an expression flagged as containing an aggregate also sets nr.Err.
func (nr *nameResolver) handlePosition(pos *ast.PositionExpr) {
	ctx := nr.currentContext()
	idx := pos.N - 1
	if idx < 0 || idx >= len(ctx.fieldList) {
		nr.Err = errors.Errorf("Unknown column '%d'", pos.N)
		return
	}
	pos.Refer = ctx.fieldList[idx]
	if !ctx.inGroupBy {
		return
	}
	// GROUP BY may not reference an aggregate function by position.
	if ast.HasAggFlag(pos.Refer.Expr) {
		nr.Err = errors.New("group by cannot contain aggregate function")
	}
}
Example #5
0
// TestHasAggFlag checks that ast.HasAggFlag reports true for flag sets that
// include FlagHasAggregateFunc and false for one that does not. A single
// expression is reused, with its flags set via SetFlag on each iteration.
func (ts *testFlagSuite) TestHasAggFlag(c *C) {
	tests := []struct {
		flag uint64
		want bool
	}{
		{flag: ast.FlagHasAggregateFunc, want: true},
		{flag: ast.FlagHasAggregateFunc | ast.FlagHasVariable, want: true},
		{flag: ast.FlagHasVariable, want: false},
	}
	expr := &ast.BetweenExpr{}
	for i := range tests {
		expr.SetFlag(tests[i].flag)
		c.Assert(ast.HasAggFlag(expr), Equals, tests[i].want)
	}
}
Example #6
0
// handlePosition resolves a 1-based positional reference (such as the "1" in
// "GROUP BY 1") against the current context's field list.
//
// On success pos.Refer points to a copy of the referenced field. When that
// field's expression is a resolved column name, the copy's Expr is replaced
// by the expression the column name refers to. The field-list entry itself
// is left unmodified.
//
// An out-of-range position sets nr.Err. Inside GROUP BY, a position that
// resolves to an expression flagged as containing an aggregate also sets
// nr.Err.
func (nr *nameResolver) handlePosition(pos *ast.PositionExpr) {
	ctx := nr.currentContext()
	if pos.N < 1 || pos.N > len(ctx.fieldList) {
		nr.Err = errors.Errorf("Unknown column '%d'", pos.N)
		return
	}
	matched := ctx.fieldList[pos.N-1]
	// Shallow-copy the field so rewriting Expr below does not alter the
	// entry stored in the field list.
	nf := *matched
	// Unwrap a column-name field to the expression it was resolved to.
	// Refer is checked for nil so that an earlier resolution failure cannot
	// turn into a nil-pointer panic here.
	if cexpr, ok := nf.Expr.(*ast.ColumnNameExpr); ok && cexpr.Refer != nil {
		nf.Expr = cexpr.Refer.Expr
	}
	pos.Refer = &nf
	// GROUP BY may not reference an aggregate function by position.
	if ctx.inGroupBy && ast.HasAggFlag(nf.Expr) {
		nr.Err = errors.New("group by cannot contain aggregate function")
	}
}
Example #7
0
// Enter implements ast.Visitor interface.
//
// Before a node's children are visited it:
//   - calls nr.pushContext() for statement nodes. The new context is also
//     flagged inCreateOrDropTable for CREATE/DROP TABLE, and inShow (plus
//     nr.fillShowFields) for SHOW;
//   - sets a clause flag (inDeleteTableList, inFieldList, inGroupBy,
//     inHaving, inOnCondition, inOrderBy, inTableRefs) on the current context
//     for clause nodes;
//   - records joins with nr.pushJoin;
//   - rewrites unqualified column names on the right of SET assignments into
//     string value expressions;
//   - rejects GROUP BY items containing an aggregate function by setting
//     nr.Err to ErrInvalidGroupFuncUse and skipping the item's children.
//
// The visited node itself is always returned as outNode.
func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
	switch node := inNode.(type) {
	// Statements that only open a new context.
	case *ast.AdminStmt, *ast.AlterTableStmt, *ast.AnalyzeTableStmt,
		*ast.CreateIndexStmt, *ast.DeleteStmt, *ast.DoStmt,
		*ast.DropIndexStmt, *ast.InsertStmt, *ast.SelectStmt,
		*ast.TruncateTableStmt, *ast.UnionStmt, *ast.UpdateStmt:
		nr.pushContext()
	case *ast.CreateTableStmt, *ast.DropTableStmt:
		nr.pushContext()
		nr.currentContext().inCreateOrDropTable = true
	case *ast.ShowStmt:
		nr.pushContext()
		nr.currentContext().inShow = true
		nr.fillShowFields(node)
	case *ast.SetStmt:
		for _, assign := range node.Variables {
			cn, isColumn := assign.Value.(*ast.ColumnNameExpr)
			if !isColumn || cn.Name.Table.L != "" {
				continue
			}
			// Turn an unqualified column name into a string value expression.
			// NOTE(review): presumably so bare identifiers on the right of SET
			// act as literal values; confirm against the SET semantics used.
			assign.Value = ast.NewValueExpr(cn.Name.Name.O)
		}
		nr.pushContext()
	case *ast.Join:
		nr.pushJoin(node)
	case *ast.AggregateFuncExpr:
		// Note that an aggregate was seen while the HAVING flag is set.
		if ctx := nr.currentContext(); ctx.inHaving {
			ctx.inHavingAgg = true
		}
	case *ast.ByItem:
		if _, isColumn := node.Expr.(*ast.ColumnNameExpr); !isColumn {
			// A by-item that is not a bare column name is resolved with a
			// different rule than a plain ORDER BY column.
			nr.currentContext().inByItemExpression = true
		}
		// A GROUP BY item must not contain an aggregate function.
		if nr.currentContext().inGroupBy && ast.HasAggFlag(node.Expr) {
			nr.Err = ErrInvalidGroupFuncUse
			return inNode, true
		}
	// Clause nodes: mark which clause is being visited.
	case *ast.DeleteTableList:
		nr.currentContext().inDeleteTableList = true
	case *ast.FieldList:
		nr.currentContext().inFieldList = true
	case *ast.GroupByClause:
		nr.currentContext().inGroupBy = true
	case *ast.HavingClause:
		nr.currentContext().inHaving = true
	case *ast.OnCondition:
		nr.currentContext().inOnCondition = true
	case *ast.OrderByClause:
		nr.currentContext().inOrderBy = true
	case *ast.TableRefsClause:
		nr.currentContext().inTableRefs = true
	}
	return inNode, false
}