func getMatch(node sqlparser.BoolExpr, index *Index) (planID PlanID, values interface{}) { switch node := node.(type) { case *sqlparser.AndExpr: if planID, values = getMatch(node.Left, index); planID != SelectScatter { return planID, values } if planID, values = getMatch(node.Right, index); planID != SelectScatter { return planID, values } case *sqlparser.ParenBoolExpr: return getMatch(node.Expr, index) case *sqlparser.ComparisonExpr: switch node.Operator { case "=": if !nameMatch(node.Left, index.Column) { return SelectScatter, nil } if !sqlparser.IsValue(node.Right) { return SelectScatter, nil } val, err := sqlparser.AsInterface(node.Right) if err != nil { return SelectScatter, nil } if index.Type == ShardKey { planID = SelectSingleShardKey } else { planID = SelectSingleLookup } return planID, val case "in": if !nameMatch(node.Left, index.Column) { return SelectScatter, nil } if !sqlparser.IsSimpleTuple(node.Right) { return SelectScatter, nil } val, err := sqlparser.AsInterface(node.Right) if err != nil { return SelectScatter, nil } node.Right = sqlparser.ListArg("::_vals") if index.Type == ShardKey { planID = SelectMultiShardKey } else { planID = SelectMultiLookup } return planID, val } } return SelectScatter, nil }
func getMatch(node sqlparser.BoolExpr, col string) (planID PlanID, values interface{}) { switch node := node.(type) { case *sqlparser.AndExpr: if planID, values = getMatch(node.Left, col); planID != SelectScatter { return planID, values } if planID, values = getMatch(node.Right, col); planID != SelectScatter { return planID, values } case *sqlparser.ParenBoolExpr: return getMatch(node.Expr, col) case *sqlparser.ComparisonExpr: switch node.Operator { case "=": if !nameMatch(node.Left, col) { return SelectScatter, nil } if !sqlparser.IsValue(node.Right) { return SelectScatter, nil } val, err := asInterface(node.Right) if err != nil { return SelectScatter, nil } return SelectEqual, val case "in": if !nameMatch(node.Left, col) { return SelectScatter, nil } if !sqlparser.IsSimpleTuple(node.Right) { return SelectScatter, nil } val, err := asInterface(node.Right) if err != nil { return SelectScatter, nil } node.Right = sqlparser.ListArg("::" + ListVarName) return SelectIN, val } } return SelectScatter, nil }
// Wireup performs the wire-up tasks. func (rb *route) Wireup(bldr builder, jt *jointab) error { // Resolve values stored in the builder. var err error switch vals := rb.ERoute.Values.(type) { case *sqlparser.ComparisonExpr: // A comparison expression is stored only if it was an IN clause. // We have to convert it to use a list argutment and resolve values. rb.ERoute.Values, err = rb.procureValues(bldr, jt, vals.Right) if err != nil { return err } vals.Right = sqlparser.ListArg("::" + engine.ListVarName) default: rb.ERoute.Values, err = rb.procureValues(bldr, jt, vals) if err != nil { return err } } // Fix up the AST. _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { switch node := node.(type) { case *sqlparser.Select: if len(node.SelectExprs) == 0 { node.SelectExprs = sqlparser.SelectExprs([]sqlparser.SelectExpr{ &sqlparser.NonStarExpr{ Expr: sqlparser.NumVal([]byte{'1'}), }, }) } case *sqlparser.ComparisonExpr: if node.Operator == sqlparser.EqualStr { if exprIsValue(node.Left, rb) && !exprIsValue(node.Right, rb) { node.Left, node.Right = node.Right, node.Left } } } return true, nil }, &rb.Select) // Generate query while simultaneously resolving values. varFormatter := func(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) { switch node := node.(type) { case *sqlparser.ColName: if !rb.isLocal(node) { joinVar := jt.Procure(bldr, node, rb.Order()) rb.ERoute.JoinVars[joinVar] = struct{}{} buf.Myprintf("%a", ":"+joinVar) return } case *sqlparser.TableName: node.Name.Format(buf) return } node.Format(buf) } buf := sqlparser.NewTrackedBuffer(varFormatter) varFormatter(buf, &rb.Select) rb.ERoute.Query = buf.ParsedQuery().Query rb.ERoute.FieldQuery = rb.generateFieldQuery(&rb.Select, jt) return nil }