Ejemplo n.º 1
0
// getMatch searches node for a condition that allows the query to be routed
// through index instead of being scattered.
//
// AND expressions are searched left operand first, and the first routable
// match wins. Parenthesized expressions are unwrapped.
//
// An "=" comparison between index.Column and a value yields
// SelectSingleShardKey (for a ShardKey index) or SelectSingleLookup, together
// with the value.
//
// An "in" comparison between index.Column and a simple tuple yields
// SelectMultiShardKey or SelectMultiLookup, together with the tuple values.
// In that case the tuple in the AST is replaced in place by the "::_vals"
// list argument.
//
// Every other shape yields SelectScatter with nil values.
func getMatch(node sqlparser.BoolExpr, index *Index) (planID PlanID, values interface{}) {
	switch expr := node.(type) {
	case *sqlparser.AndExpr:
		if id, v := getMatch(expr.Left, index); id != SelectScatter {
			return id, v
		}
		if id, v := getMatch(expr.Right, index); id != SelectScatter {
			return id, v
		}
	case *sqlparser.ParenBoolExpr:
		return getMatch(expr.Expr, index)
	case *sqlparser.ComparisonExpr:
		isIn := expr.Operator == "in"
		if !isIn && expr.Operator != "=" {
			break
		}
		if !nameMatch(expr.Left, index.Column) {
			return SelectScatter, nil
		}
		// "=" needs a plain value on the right; "in" needs a simple tuple.
		var rhsOK bool
		if isIn {
			rhsOK = sqlparser.IsSimpleTuple(expr.Right)
		} else {
			rhsOK = sqlparser.IsValue(expr.Right)
		}
		if !rhsOK {
			return SelectScatter, nil
		}
		val, err := sqlparser.AsInterface(expr.Right)
		if err != nil {
			return SelectScatter, nil
		}
		if !isIn {
			if index.Type == ShardKey {
				return SelectSingleShardKey, val
			}
			return SelectSingleLookup, val
		}
		// The extracted values are returned to the caller; the AST now refers
		// to them through a single list argument.
		expr.Right = sqlparser.ListArg("::_vals")
		if index.Type == ShardKey {
			return SelectMultiShardKey, val
		}
		return SelectMultiLookup, val
	}
	return SelectScatter, nil
}
Ejemplo n.º 2
0
// getMatch searches node for a condition on column col that lets the query
// avoid a scatter.
//
// AND expressions are searched left operand first, and the first match wins.
// Parenthesized expressions are unwrapped.
//
// An "=" comparison between col and a value yields SelectEqual with that
// value.
//
// An "in" comparison between col and a simple tuple yields SelectIN with the
// tuple values. In that case the tuple in the AST is replaced in place by the
// "::"+ListVarName list argument.
//
// Every other shape yields SelectScatter with nil values.
func getMatch(node sqlparser.BoolExpr, col string) (planID PlanID, values interface{}) {
	switch expr := node.(type) {
	case *sqlparser.AndExpr:
		if id, v := getMatch(expr.Left, col); id != SelectScatter {
			return id, v
		}
		if id, v := getMatch(expr.Right, col); id != SelectScatter {
			return id, v
		}
	case *sqlparser.ParenBoolExpr:
		return getMatch(expr.Expr, col)
	case *sqlparser.ComparisonExpr:
		isIn := expr.Operator == "in"
		if !isIn && expr.Operator != "=" {
			break
		}
		if !nameMatch(expr.Left, col) {
			return SelectScatter, nil
		}
		// "=" needs a plain value on the right; "in" needs a simple tuple.
		var rhsOK bool
		if isIn {
			rhsOK = sqlparser.IsSimpleTuple(expr.Right)
		} else {
			rhsOK = sqlparser.IsValue(expr.Right)
		}
		if !rhsOK {
			return SelectScatter, nil
		}
		val, err := asInterface(expr.Right)
		if err != nil {
			return SelectScatter, nil
		}
		if !isIn {
			return SelectEqual, val
		}
		// The extracted values go back to the caller; the AST now binds them
		// through a single list variable.
		expr.Right = sqlparser.ListArg("::" + ListVarName)
		return SelectIN, val
	}
	return SelectScatter, nil
}
Ejemplo n.º 3
0
// Wireup performs the wire-up tasks for the route. It proceeds in three steps:
//
//  1. It resolves the routing values stored in rb.ERoute.Values through
//     rb.procureValues.
//  2. It normalizes the route's SELECT AST.
//  3. It renders rb.ERoute.Query and rb.ERoute.FieldQuery from that AST.
//
// It returns the first error produced while resolving values or walking the
// AST.
func (rb *route) Wireup(bldr builder, jt *jointab) error {
	// Resolve values stored in the builder.
	var err error
	switch vals := rb.ERoute.Values.(type) {
	case *sqlparser.ComparisonExpr:
		// A comparison expression is stored only if it was an IN clause.
		// We have to convert it to use a list argument and resolve values.
		rb.ERoute.Values, err = rb.procureValues(bldr, jt, vals.Right)
		if err != nil {
			return err
		}
		// The comparison node is mutated in place, so the query generated
		// below refers to the resolved values through a single list variable.
		vals.Right = sqlparser.ListArg("::" + engine.ListVarName)
	default:
		rb.ERoute.Values, err = rb.procureValues(bldr, jt, vals)
		if err != nil {
			return err
		}
	}

	// Fix up the AST.
	//   - A SELECT with no select expressions gets a single literal 1.
	//   - For "=" comparisons whose left operand is a value and whose right
	//     operand is not (per exprIsValue), the operands are swapped.
	// The visit function never fails. The error is still checked rather than
	// discarded, so a future failing case cannot be silently lost.
	err = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
		switch node := node.(type) {
		case *sqlparser.Select:
			if len(node.SelectExprs) == 0 {
				node.SelectExprs = sqlparser.SelectExprs([]sqlparser.SelectExpr{
					&sqlparser.NonStarExpr{
						Expr: sqlparser.NumVal([]byte{'1'}),
					},
				})
			}
		case *sqlparser.ComparisonExpr:
			if node.Operator == sqlparser.EqualStr {
				if exprIsValue(node.Left, rb) && !exprIsValue(node.Right, rb) {
					node.Left, node.Right = node.Right, node.Left
				}
			}
		}
		return true, nil
	}, &rb.Select)
	if err != nil {
		return err
	}

	// Generate query while simultaneously resolving values.
	varFormatter := func(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) {
		switch node := node.(type) {
		case *sqlparser.ColName:
			// Columns that rb.isLocal rejects are replaced by a join variable
			// obtained from jt, and that variable is recorded in
			// rb.ERoute.JoinVars.
			if !rb.isLocal(node) {
				joinVar := jt.Procure(bldr, node, rb.Order())
				rb.ERoute.JoinVars[joinVar] = struct{}{}
				buf.Myprintf("%a", ":"+joinVar)
				return
			}
		case *sqlparser.TableName:
			// Only node.Name is printed for table names.
			node.Name.Format(buf)
			return
		}
		node.Format(buf)
	}
	buf := sqlparser.NewTrackedBuffer(varFormatter)
	varFormatter(buf, &rb.Select)
	rb.ERoute.Query = buf.ParsedQuery().Query
	rb.ERoute.FieldQuery = rb.generateFieldQuery(&rb.Select, jt)
	return nil
}