Ejemplo n.º 1
0
// split splits the query into multiple queries. validateQuery() must return
// nil error before split() is called.
//
// The boundaries derived from pkMinMax cut the primary-key range into
// len(boundaries)+1 consecutive ranges. Each range becomes one split whose
// WHERE clause is the original clause combined with that range's bounds; the
// first range has no lower bound and the last has no upper bound. If no
// boundaries are produced, the original query is returned unchanged as the
// only split.
//
// split does not modify qs.sel, so it is safe to call more than once.
func (qs *QuerySplitter) split(pkMinMax *mproto.QueryResult) []proto.QuerySplit {
	boundaries := qs.getSplitBoundaries(pkMinMax)

	// No splits, return the original query as a single split.
	if len(boundaries) == 0 {
		return []proto.QuerySplit{{Query: *qs.query}}
	}

	// Loop through the boundaries and generate the modified where clauses.
	// A zero sqltypes.Value stands for "no bound" on that side.
	start := sqltypes.Value{}
	clauses := make([]*sqlparser.Where, 0, len(boundaries)+1)
	for _, end := range boundaries {
		clauses = append(clauses, qs.getWhereClause(start, end))
		start.Inner = end.Inner
	}
	clauses = append(clauses, qs.getWhereClause(start, sqltypes.Value{}))

	// Generate one split per clause.
	splits := make([]proto.QuerySplit, 0, len(clauses))
	for _, clause := range clauses {
		// Copy the Select by value: qs.sel is a pointer, and assigning Where
		// through it would permanently replace the original WHERE clause that
		// getWhereClause builds on, corrupting every later call to split.
		sel := *qs.sel
		sel.Where = clause
		splits = append(splits, proto.QuerySplit{
			Query: proto.BoundQuery{
				Sql:           sqlparser.String(&sel),
				BindVariables: qs.query.BindVariables,
			},
			RowCount: qs.rowCount,
		})
	}
	return splits
}
Ejemplo n.º 2
0
// GetStreamExecPlan parses sql and builds a PLAN_SELECT_STREAM execution plan
// for it.
//
// A parse failure is returned unchanged. Only UNION statements and SELECT
// statements without a lock clause may be streamed. A locking SELECT, or any
// other kind of statement, yields a nil plan and an error. For an accepted
// SELECT, when analyzeFrom reports a table name for the FROM clause, the
// plan's table info is populated through getTable.
func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
	parsed, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, err
	}

	plan = &ExecPlan{
		PlanId:    PLAN_SELECT_STREAM,
		FullQuery: GenerateFullQuery(parsed),
	}

	switch node := parsed.(type) {
	case *sqlparser.Union:
		// Unions are streamed without attaching any table info.
		return plan, nil
	case *sqlparser.Select:
		if node.Lock != "" {
			return nil, errors.New("select with lock disallowed with streaming")
		}
		if tableName, _ := analyzeFrom(node.From); tableName != "" {
			plan.setTableInfo(tableName, getTable)
		}
		return plan, nil
	}

	// Anything that is neither a SELECT nor a UNION cannot be streamed.
	return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(parsed))
}
Ejemplo n.º 3
0
// asInterface is similar to sqlparser.AsInterface, but it converts
// numeric and string types to native go types.
//
// Conversions performed:
//   - *NullVal: nil, with no error.
//   - StrVal: the literal's bytes as []byte.
//   - ValArg, ListArg: the argument text as a string.
//   - NumVal: int64 if the text parses as a signed 64-bit integer, otherwise
//     uint64. Parsing uses strconv base 0, so Go-style prefixes such as 0x
//     are honored. If neither parse succeeds, the strconv.ParseUint error is
//     returned.
//   - ValTuple: a []interface{} of the elements converted recursively. The
//     first element that fails aborts the conversion with its error.
//
// Any other node yields an "unexpected node" error.
func asInterface(node sqlparser.ValExpr) (interface{}, error) {
	switch node := node.(type) {
	case *sqlparser.NullVal:
		return nil, nil
	case sqlparser.StrVal:
		return []byte(node), nil
	case sqlparser.ValArg:
		return string(node), nil
	case sqlparser.ListArg:
		return string(node), nil
	case sqlparser.NumVal:
		text := string(node)
		if i64, err := strconv.ParseInt(text, 0, 64); err == nil {
			return i64, nil
		}
		// Fall back to unsigned for values above math.MaxInt64.
		u64, err := strconv.ParseUint(text, 0, 64)
		if err != nil {
			return nil, err
		}
		return u64, nil
	case sqlparser.ValTuple:
		converted := make([]interface{}, len(node))
		for i, elem := range node {
			v, err := asInterface(elem)
			if err != nil {
				return nil, err
			}
			converted[i] = v
		}
		return converted, nil
	default:
		return nil, fmt.Errorf("unexpected node %v", sqlparser.String(node))
	}
}
Ejemplo n.º 4
0
// encodePKValues interprets the parsed node and encodes the primary key
// values of one row.
//
// Each element of tuple is converted as follows:
//   - StrVal becomes []byte.
//   - NumVal becomes int64 when it parses as one, otherwise uint64. Parsing
//     uses strconv base 0.
//   - *NullVal is replaced by the current counter value, which starts at
//     insertid and is incremented after each NULL. These are presumably
//     auto-increment ids; confirm with the caller.
//
// newinsertid is the counter value after the whole tuple has been processed.
// Any other node type, or a NumVal that parses as neither int64 nor uint64,
// aborts with a nil row, the counter value reached so far, and an error.
func encodePKValues(tuple sqlparser.ValTuple, insertid int64) (rowPk []interface{}, newinsertid int64, err error) {
	var values []interface{}
	nextID := insertid
	for _, expr := range tuple {
		switch lit := expr.(type) {
		case *sqlparser.NullVal:
			values = append(values, nextID)
			nextID++
		case sqlparser.StrVal:
			values = append(values, []byte(lit))
		case sqlparser.NumVal:
			text := string(lit)
			signed, parseErr := strconv.ParseInt(text, 0, 64)
			if parseErr == nil {
				values = append(values, signed)
				continue
			}
			unsigned, parseErr := strconv.ParseUint(text, 0, 64)
			if parseErr != nil {
				return nil, nextID, parseErr
			}
			values = append(values, unsigned)
		default:
			return nil, nextID, fmt.Errorf("unexpected token: '%v'", sqlparser.String(lit))
		}
	}
	return values, nextID, nil
}
Ejemplo n.º 5
0
// TestGetWhereClause verifies that getWhereClause ANDs the primary-key range
// conditions "id >= start" and "id < end" onto the query's existing WHERE
// clause. A side whose bound is the zero sqltypes.Value must get no condition.
// When the original query has no WHERE clause, the result must be a fresh
// clause holding only the range conditions, or no clause at all.
func TestGetWhereClause(t *testing.T) {
	splitter := &QuerySplitter{}
	splitter.pkCol = "id"

	// setQuery parses sql and installs it as the splitter's SELECT. Failures
	// stop the test here instead of leaving splitter.sel nil, which would
	// otherwise surface later as a nil-pointer panic in getWhereClause.
	setQuery := func(sql string) {
		statement, err := sqlparser.Parse(sql)
		if err != nil {
			t.Fatalf("sqlparser.Parse(%q): %v", sql, err)
		}
		sel, ok := statement.(*sqlparser.Select)
		if !ok {
			t.Fatalf("sqlparser.Parse(%q) returned %T, want *sqlparser.Select", sql, statement)
		}
		splitter.sel = sel
	}

	// check renders the WHERE clause for the range (lo, hi) and compares it
	// with want.
	check := func(desc string, lo, hi sqltypes.Value, want string) {
		got := sqlparser.String(splitter.getWhereClause(lo, hi))
		if !reflect.DeepEqual(got, want) {
			t.Errorf("incorrect where clause for %s, got:%v, want:%v", desc, got, want)
		}
	}

	nilValue := sqltypes.Value{}
	start, err := sqltypes.BuildValue(20)
	if err != nil {
		t.Fatalf("sqltypes.BuildValue(20): %v", err)
	}
	end, err := sqltypes.BuildValue(40)
	if err != nil {
		t.Fatalf("sqltypes.BuildValue(40): %v", err)
	}

	// Original query with a where clause: bounds are ANDed onto it, and no
	// bounds leaves it unchanged.
	setQuery("select * from test_table where count > :count")
	check("nil ranges", nilValue, nilValue, " where count > :count")
	check("lower bound only", start, nilValue, " where count > :count and id >= 20")
	check("upper bound only", nilValue, end, " where count > :count and id < 40")
	check("both bounds", start, end, " where count > :count and id >= 20 and id < 40")

	// Original query with no where clause: no bounds yields no where clause.
	setQuery("select * from test_table")
	check("nil ranges without where", nilValue, nilValue, "")
	check("both bounds without where", start, end, " where id >= 20 and id < 40")
}