// parsePkTuple parese one pk tuple. func parsePkTuple(tokenizer *sqlparser.Tokenizer) (tuple sqlparser.ValTuple, err error) { // start scanning the list for typ, val := tokenizer.Scan(); typ != ')'; typ, val = tokenizer.Scan() { switch typ { case '-': // handle negative numbers typ2, val2 := tokenizer.Scan() if typ2 != sqlparser.NUMBER { return nil, fmt.Errorf("expecing number after '-'") } num := append(sqlparser.NumVal("-"), val2...) tuple = append(tuple, num) case sqlparser.NUMBER: tuple = append(tuple, sqlparser.NumVal(val)) case sqlparser.NULL: tuple = append(tuple, new(sqlparser.NullVal)) case sqlparser.STRING: decoded := make([]byte, base64.StdEncoding.DecodedLen(len(val))) numDecoded, err := base64.StdEncoding.Decode(decoded, val) if err != nil { return nil, err } tuple = append(tuple, sqlparser.StrVal(decoded[:numDecoded])) default: return nil, fmt.Errorf("syntax error at position: %d", tokenizer.Position) } } return tuple, nil }
// getWhereClause returns a whereClause based on desired upper and lower // bounds for primary key. func (qs *QuerySplitter) getWhereClause(start, end sqltypes.Value) *sqlparser.Where { var startClause *sqlparser.ComparisonExpr var endClause *sqlparser.ComparisonExpr var clauses sqlparser.BoolExpr // No upper or lower bound, just return the where clause of original query if start.IsNull() && end.IsNull() { return qs.sel.Where } pk := &sqlparser.ColName{ Name: []byte(qs.splitColumn), } // splitColumn >= start if !start.IsNull() { startClause = &sqlparser.ComparisonExpr{ Operator: sqlparser.AST_GE, Left: pk, Right: sqlparser.NumVal((start).Raw()), } } // splitColumn < end if !end.IsNull() { endClause = &sqlparser.ComparisonExpr{ Operator: sqlparser.AST_LT, Left: pk, Right: sqlparser.NumVal((end).Raw()), } } if startClause == nil { clauses = endClause } else { if endClause == nil { clauses = startClause } else { // splitColumn >= start AND splitColumn < end clauses = &sqlparser.AndExpr{ Left: startClause, Right: endClause, } } } if qs.sel.Where != nil { clauses = &sqlparser.AndExpr{ Left: qs.sel.Where.Expr, Right: clauses, } } return &sqlparser.Where{ Type: sqlparser.AST_WHERE, Expr: clauses, } }
// pushOrderBy pushes the order by clause to the appropriate routes. // In the case of a join, this is allowed only if the order by columns // match the join order. Otherwise, it's an error. // If column numbers were used to reference the columns, those numbers // are readjusted on push-down to match the numbers of the individual // queries. func pushOrderBy(orderBy sqlparser.OrderBy, bldr builder) error { if orderBy == nil { return nil } routeNumber := 0 for _, order := range orderBy { // Only generator is allowed to change the AST. // If we have to change the order by expression, // we have to build a new node. pushOrder := order var rb *route switch node := order.Expr.(type) { case *sqlparser.ColName: var isLocal bool var err error rb, isLocal, err = bldr.Symtab().Find(node, true) if err != nil { return err } if !isLocal { return errors.New("unsupported: subquery references outer query in order by") } case sqlparser.NumVal: num, err := strconv.ParseInt(string(node), 0, 64) if err != nil { return fmt.Errorf("error parsing order by clause: %s", string(node)) } if num < 1 || num > int64(len(bldr.Symtab().Colsyms)) { return errors.New("order by column number out of range") } colsym := bldr.Symtab().Colsyms[num-1] rb = colsym.Route() // We have to recompute the column number. for num, s := range rb.Colsyms { if s == colsym { pushOrder = &sqlparser.Order{ Expr: sqlparser.NumVal(strconv.AppendInt(nil, int64(num+1), 10)), Direction: order.Direction, } } } if pushOrder == order { panic("unexpected: column not found for order by") } default: return errors.New("unsupported: complex expression in order by") } if rb.Order() < routeNumber { return errors.New("unsupported: complex join and out of sequence order by") } if !rb.IsSingle() { return errors.New("unsupported: scatter and order by") } routeNumber = rb.Order() if err := rb.AddOrder(pushOrder); err != nil { return err } } return nil }
// buildLimitClause constructs a LIMIT node whose Offset and Rowcount fields
// hold the base-10 renderings of offset and rowcount as sqlparser.NumVal
// literals.
func buildLimitClause(offset, rowcount int64) *sqlparser.Limit {
	limit := new(sqlparser.Limit)
	limit.Offset = sqlparser.NumVal(fmt.Sprintf("%d", offset))
	limit.Rowcount = sqlparser.NumVal(fmt.Sprintf("%d", rowcount))
	return limit
}
// Wireup performs the wire-up tasks for this route. It runs three steps in
// order:
//  1. It resolves the values stored in rb.ERoute.Values.
//  2. It fixes up rb.Select in place.
//  3. It generates rb.ERoute.Query and rb.ERoute.FieldQuery.
//
// bldr and jt are used to resolve values (via procureValues) and to allocate
// join variables for columns that are not local to this route (via
// jt.Procure). The only error returned is one from procureValues.
func (rb *route) Wireup(bldr builder, jt *jointab) error {
	// Resolve values stored in the builder.
	var err error
	switch vals := rb.ERoute.Values.(type) {
	case *sqlparser.ComparisonExpr:
		// A comparison expression is stored only if it was an IN clause.
		// We have to convert it to use a list argument and resolve values.
		// vals.Right is read by procureValues before it is replaced with the
		// list argument below, so these two statements must stay in this order.
		rb.ERoute.Values, err = rb.procureValues(bldr, jt, vals.Right)
		if err != nil {
			return err
		}
		vals.Right = sqlparser.ListArg("::" + engine.ListVarName)
	default:
		rb.ERoute.Values, err = rb.procureValues(bldr, jt, vals)
		if err != nil {
			return err
		}
	}
	// Fix up the AST.
	// The walk callback below always returns a nil error, so Walk's result is
	// deliberately discarded.
	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
		switch node := node.(type) {
		case *sqlparser.Select:
			// A SELECT with an empty select list gets the constant "1",
			// presumably so the generated SQL stays syntactically valid.
			if len(node.SelectExprs) == 0 {
				node.SelectExprs = sqlparser.SelectExprs([]sqlparser.SelectExpr{
					&sqlparser.NonStarExpr{
						Expr: sqlparser.NumVal([]byte{'1'}),
					},
				})
			}
		case *sqlparser.ComparisonExpr:
			// For "=" comparisons, swap the operands when only the left side
			// is a value (per exprIsValue), putting the non-value on the left.
			if node.Operator == sqlparser.EqualStr {
				if exprIsValue(node.Left, rb) && !exprIsValue(node.Right, rb) {
					node.Left, node.Right = node.Right, node.Left
				}
			}
		}
		return true, nil
	}, &rb.Select)
	// Generate query while simultaneously resolving values.
	varFormatter := func(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) {
		switch node := node.(type) {
		case *sqlparser.ColName:
			// A column that is not local to this route is emitted as a bind
			// variable (":" + joinVar) obtained from the join table and
			// recorded in rb.ERoute.JoinVars. Local columns fall through to
			// the default formatting below.
			if !rb.isLocal(node) {
				joinVar := jt.Procure(bldr, node, rb.Order())
				rb.ERoute.JoinVars[joinVar] = struct{}{}
				buf.Myprintf("%a", ":"+joinVar)
				return
			}
		case *sqlparser.TableName:
			// Only node.Name is emitted; any other part of the TableName
			// (presumably a qualifier) is dropped.
			node.Name.Format(buf)
			return
		}
		node.Format(buf)
	}
	buf := sqlparser.NewTrackedBuffer(varFormatter)
	varFormatter(buf, &rb.Select)
	rb.ERoute.Query = buf.ParsedQuery().Query
	rb.ERoute.FieldQuery = rb.generateFieldQuery(&rb.Select, jt)
	return nil
}