Exemplo n.º 1
0
// split splits the query into multiple queries, one per range of the split
// column delimited by consecutive boundaries (with open-ended first and last
// ranges). validateQuery() must return nil error before split() is called.
//
// When there are no boundaries, the original query is returned unchanged as
// the only split. Otherwise every split carries the original SELECT with its
// WHERE clause narrowed to one range, the original bind variables, and
// qs.rowCount as its RowCount.
//
// qs.sel.Where is temporarily replaced while each split's SQL is rendered and
// is restored before returning, so split can safely be called more than once.
func (qs *QuerySplitter) split(pkMinMax *mproto.QueryResult) []proto.QuerySplit {
	boundaries := qs.getSplitBoundaries(pkMinMax)
	// No splits, return the original query as a single split.
	if len(boundaries) == 0 {
		return []proto.QuerySplit{{Query: *qs.query}}
	}

	// Build every where clause first: getWhereClause reads qs.sel.Where, which
	// must still hold the original clause at this point.
	start := sqltypes.Value{}
	clauses := make([]*sqlparser.Where, 0, len(boundaries)+1)
	for _, end := range boundaries {
		clauses = append(clauses, qs.getWhereClause(start, end))
		start.Inner = end.Inner
	}
	clauses = append(clauses, qs.getWhereClause(start, sqltypes.Value{}))

	// qs.sel is shared splitter state; put the original WHERE clause back once
	// all splits have been rendered instead of leaving the last range behind.
	origWhere := qs.sel.Where
	defer func() { qs.sel.Where = origWhere }()

	// Generate one split per clause.
	splits := make([]proto.QuerySplit, 0, len(clauses))
	for _, clause := range clauses {
		qs.sel.Where = clause
		splits = append(splits, proto.QuerySplit{
			Query: proto.BoundQuery{
				Sql:           sqlparser.String(qs.sel),
				BindVariables: qs.query.BindVariables,
			},
			RowCount: qs.rowCount,
		})
	}
	return splits
}
Exemplo n.º 2
0
// SetMysqlStats records the statistics read from MySQL's
// information_schema.tables table: table rows (tr), data length (dl), index
// length (il), data free (df) and max data length (mdl).
//
// Parse errors are ignored; each statistic is set to whatever ParseInt64
// returns for its value.
func (ta *Table) SetMysqlStats(tr, dl, il, df, mdl sqltypes.Value) {
	toInt := func(val sqltypes.Value) int64 {
		n, _ := val.ParseInt64() // best-effort: statistics are informational only
		return n
	}
	ta.TableRows.Set(toInt(tr))
	ta.DataLength.Set(toInt(dl))
	ta.IndexLength.Set(toInt(il))
	ta.DataFree.Set(toInt(df))
	ta.MaxDataLength.Set(toInt(mdl))
}
Exemplo n.º 3
0
// AddColumn appends a column to the Table. The column name is stored in
// lower case and the column gets columnType as its type.
//
// When extra is "auto_increment" the column is flagged IsAuto and any default
// is dropped. Otherwise a non-NULL defval becomes the column's Default,
// rebuilt as a trusted value of columnType.
func (ta *Table) AddColumn(name string, columnType querypb.Type, defval sqltypes.Value, extra string) {
	col := TableColumn{
		Name: strings.ToLower(name),
		Type: columnType,
	}
	switch {
	case extra == "auto_increment":
		// Auto-increment columns never carry a default value.
		col.IsAuto = true
	case !defval.IsNull():
		// Schema values are trusted.
		col.Default = sqltypes.MakeTrusted(columnType, defval.Raw())
	}
	ta.Columns = append(ta.Columns, col)
}
Exemplo n.º 4
0
// Convert takes a field description and a value, and returns the value as:
// - nil for NULL (a NULL field type or a NULL value)
// - int64 for types reported signed by sqltypes.IsSigned
// - uint64 for types reported unsigned by sqltypes.IsUnsigned
// - float64 for types reported by sqltypes.IsFloat
// - []byte for everything else
//
// Integer text is parsed in base 10. MySQL sends integers as decimal text, and
// ZEROFILL columns can produce leading zeros (e.g. "00010") that base-0
// parsing would treat as an octal prefix, yielding a wrong value or an error.
func Convert(field *querypb.Field, val sqltypes.Value) (interface{}, error) {
	switch {
	case field.Type == sqltypes.Null || val.IsNull():
		// A NULL value in a typed column must not reach the parsers below,
		// which would fail on its empty text.
		return nil, nil
	case sqltypes.IsSigned(field.Type):
		return strconv.ParseInt(val.String(), 10, 64)
	case sqltypes.IsUnsigned(field.Type):
		return strconv.ParseUint(val.String(), 10, 64)
	case sqltypes.IsFloat(field.Type):
		return strconv.ParseFloat(val.String(), 64)
	}
	return val.Raw(), nil
}
Exemplo n.º 5
0
// Convert takes a type and a value, and returns the type:
// - nil for NULL value
// - int64 for integer number types (VT_TINY, VT_SHORT, VT_LONG, VT_LONGLONG,
//   VT_INT24); signed and unsigned are all parsed as signed, so an unsigned
//   BIGINT above math.MaxInt64 is reported as a range error
// - float64 for floating point values (VT_FLOAT, VT_DOUBLE)
// - []byte for everything else
//
// Integer text is parsed in base 10. MySQL sends integers as decimal text, and
// ZEROFILL columns can produce leading zeros (e.g. "00010") that base-0
// parsing would treat as an octal prefix, yielding a wrong value or an error.
func Convert(mysqlType int64, val sqltypes.Value) (interface{}, error) {
	if val.IsNull() {
		return nil, nil
	}

	switch mysqlType {
	case VT_TINY, VT_SHORT, VT_LONG, VT_LONGLONG, VT_INT24:
		return strconv.ParseInt(val.String(), 10, 64)
	case VT_FLOAT, VT_DOUBLE:
		return strconv.ParseFloat(val.String(), 64)
	}
	return val.Raw(), nil
}
Exemplo n.º 6
0
// getWhereClause returns a where clause that restricts the original query to
// the split column range [start, end). A null start or end leaves that side of
// the range open; if both are null, qs.sel.Where (possibly nil) is returned
// unchanged.
//
// When the original query has a where clause, it and the range condition are
// each parenthesized before being ANDed. Without the parentheses an original
// condition such as "a = 1 or b = 2" would render as
// "a = 1 or b = 2 and pk >= x", and because AND binds tighter than OR the
// range would only constrain the last OR operand.
func (qs *QuerySplitter) getWhereClause(start, end sqltypes.Value) *sqlparser.Where {
	var startClause *sqlparser.ComparisonExpr
	var endClause *sqlparser.ComparisonExpr
	var clauses sqlparser.BoolExpr
	// No upper or lower bound, just return the where clause of original query
	if start.IsNull() && end.IsNull() {
		return qs.sel.Where
	}
	pk := &sqlparser.ColName{
		Name: []byte(qs.splitColumn),
	}
	// NOTE(review): bounds are inlined as numeric literals, which assumes the
	// split column is numeric — confirm validateQuery enforces this.
	// splitColumn >= start
	if !start.IsNull() {
		startClause = &sqlparser.ComparisonExpr{
			Operator: sqlparser.AST_GE,
			Left:     pk,
			Right:    sqlparser.NumVal(start.Raw()),
		}
	}
	// splitColumn < end
	if !end.IsNull() {
		endClause = &sqlparser.ComparisonExpr{
			Operator: sqlparser.AST_LT,
			Left:     pk,
			Right:    sqlparser.NumVal(end.Raw()),
		}
	}
	switch {
	case startClause == nil:
		clauses = endClause
	case endClause == nil:
		clauses = startClause
	default:
		// splitColumn >= start AND splitColumn < end
		clauses = &sqlparser.AndExpr{
			Left:  startClause,
			Right: endClause,
		}
	}
	if qs.sel.Where != nil {
		clauses = &sqlparser.AndExpr{
			Left:  &sqlparser.ParenBoolExpr{Expr: qs.sel.Where.Expr},
			Right: &sqlparser.ParenBoolExpr{Expr: clauses},
		}
	}
	return &sqlparser.Where{
		Type: sqlparser.AST_WHERE,
		Expr: clauses,
	}
}
Exemplo n.º 7
0
// computePrefix returns a cache key prefix for the table. The prefix is the
// base64fnv hash of the lower-cased "show create table" output (with autoIncr
// matches removed) followed by createTime. The hash is registered in
// hashRegistry under the table name.
//
// It logs a warning and returns "" when the table should not be cached:
// createTime is NULL, the create statement cannot be read or has an unexpected
// shape, or the hash is already present in hashRegistry.
func (ti *TableInfo) computePrefix(conn PoolConnection, createTime sqltypes.Value, hashRegistry map[string]string) string {
	if createTime.IsNull() {
		relog.Warning("%s has no time stamp. Will not be cached.", ti.Name)
		return ""
	}
	createTable, err := conn.ExecuteFetch(fmt.Sprintf("show create table %s", ti.Name), 10000, false)
	if err != nil {
		relog.Warning("Couldnt read table info: %v", err)
		return ""
	}
	// The create statement is read from the second column of the first row;
	// refuse to cache rather than panic on an empty or narrow result.
	if len(createTable.Rows) == 0 || len(createTable.Rows[0]) < 2 {
		relog.Warning("Unexpected result for show create table %s. Will not be cached.", ti.Name)
		return ""
	}
	// Normalize & remove auto_increment because it changes on every insert
	norm1 := strings.ToLower(createTable.Rows[0][1].String())
	norm2 := autoIncr.ReplaceAllLiteralString(norm1, "")
	thash := base64fnv(norm2 + createTime.String())
	if _, ok := hashRegistry[thash]; ok {
		relog.Warning("Hash collision for %s (schema revert?). Will not be cached", ti.Name)
		return ""
	}
	hashRegistry[thash] = ti.Name
	return thash
}
Exemplo n.º 8
0
// split splits the query into multiple queries. validateQuery() must return
// nil error before split() is called.
//
// Boundaries come from qs.splitBoundaries; any error from it is returned
// as-is. With no boundaries the original query is the single split. Otherwise
// each split holds the SELECT narrowed to one range of the split column, a
// private copy of the bind variables (extended with the range bounds by
// getWhereClause) and qs.rowCount. qs.sel.Where is restored afterwards.
func (qs *QuerySplitter) split(columnType int64, pkMinMax *mproto.QueryResult) ([]proto.QuerySplit, error) {
	boundaries, err := qs.splitBoundaries(columnType, pkMinMax)
	if err != nil {
		return nil, err
	}
	if len(boundaries) == 0 {
		return []proto.QuerySplit{{Query: *qs.query}}, nil
	}

	// A trailing null bound makes the final range open-ended.
	boundaries = append(boundaries, sqltypes.Value{})
	origWhere := qs.sel.Where
	splits := make([]proto.QuerySplit, 0, len(boundaries))
	var lower sqltypes.Value
	for _, upper := range boundaries {
		// Each split gets its own bind variable map so the bounds added by
		// getWhereClause do not leak between splits.
		vars := make(map[string]interface{}, len(qs.query.BindVariables))
		for name, val := range qs.query.BindVariables {
			vars[name] = val
		}
		qs.sel.Where = qs.getWhereClause(origWhere, vars, lower, upper)
		splits = append(splits, proto.QuerySplit{
			Query: proto.BoundQuery{
				Sql:           sqlparser.String(qs.sel),
				BindVariables: vars,
			},
			RowCount: qs.rowCount,
		})
		lower.Inner = upper.Inner
	}
	qs.sel.Where = origWhere
	return splits, nil
}
Exemplo n.º 9
0
// updateLastChange moves si.lastChange forward to createTime, read as a count
// of Unix seconds, when it is later than the current value compared in whole
// seconds. NULL values are ignored; values that fail to parse are logged and
// ignored.
func (si *SchemaInfo) updateLastChange(createTime sqltypes.Value) {
	if createTime.IsNull() {
		return
	}
	secs, err := strconv.ParseInt(createTime.String(), 10, 64)
	switch {
	case err != nil:
		relog.Warning("Could not parse time %s: %v", createTime.String(), err)
	case secs > si.lastChange.Unix():
		si.lastChange = time.Unix(secs, 0)
	}
}
Exemplo n.º 10
0
// validateValue panics with a FAIL TabletError when a non-NULL value does not
// match the column's category: CAT_NUMBER columns need a numeric value and
// CAT_VARBINARY columns a string value. NULL values and other categories are
// always accepted.
func validateValue(col *schema.TableColumn, value sqltypes.Value) {
	switch {
	case value.IsNull():
		// NULL is acceptable for any column.
	case col.Category == schema.CAT_NUMBER && !value.IsNumeric():
		panic(NewTabletError(FAIL, "Type mismatch, expecting numeric type for %v", value))
	case col.Category == schema.CAT_VARBINARY && !value.IsString():
		panic(NewTabletError(FAIL, "Type mismatch, expecting string type for %v", value))
	}
}
Exemplo n.º 11
0
// validateValue checks a non-NULL value against the column type: integral
// columns require a numeric value and VarBinary columns a string value. A
// mismatch is returned as an ErrFail TabletError with code BAD_INPUT; NULL
// values and all other column types pass.
func validateValue(col *schema.TableColumn, value sqltypes.Value) error {
	if value.IsNull() {
		return nil
	}
	integral := sqltypes.IsIntegral(col.Type)
	if integral && !value.IsNumeric() {
		return NewTabletError(ErrFail, vtrpc.ErrorCode_BAD_INPUT, "type mismatch, expecting numeric type for %v for column: %v", value, col)
	}
	if !integral && col.Type == sqltypes.VarBinary && !value.IsString() {
		return NewTabletError(ErrFail, vtrpc.ErrorCode_BAD_INPUT, "type mismatch, expecting string type for %v for column: %v", value, col)
	}
	return nil
}
Exemplo n.º 12
0
// validateValue checks a non-NULL value against the column category:
// CAT_NUMBER columns require a numeric value and CAT_VARBINARY columns a
// string value. A mismatch is returned as an ErrFail TabletError; NULL values
// and other categories pass.
func validateValue(col *schema.TableColumn, value sqltypes.Value) error {
	if value.IsNull() {
		return nil
	}
	switch cat := col.Category; {
	case cat == schema.CAT_NUMBER && !value.IsNumeric():
		return NewTabletError(ErrFail, "type mismatch, expecting numeric type for %v for column: %v", value, col)
	case cat == schema.CAT_VARBINARY && !value.IsString():
		return NewTabletError(ErrFail, "type mismatch, expecting string type for %v for column: %v", value, col)
	}
	return nil
}
Exemplo n.º 13
0
// AddColumn appends a column with the given name and type to the Table.
//
// An "auto_increment" extra marks the column IsAuto and discards any default.
// Otherwise a non-NULL defval becomes the Default: a numeric value for
// integral column types, a string value for everything else.
func (ta *Table) AddColumn(name string, columnType querypb.Type, defval sqltypes.Value, extra string) {
	ta.Columns = append(ta.Columns, TableColumn{Name: name, Type: columnType})
	col := &ta.Columns[len(ta.Columns)-1]
	if extra == "auto_increment" {
		col.IsAuto = true
		return
	}
	if defval.IsNull() {
		return
	}
	raw := defval.Raw()
	if sqltypes.IsIntegral(columnType) {
		col.Default = sqltypes.MakeNumeric(raw)
		return
	}
	col.Default = sqltypes.MakeString(raw)
}
Exemplo n.º 14
0
// valueToBigRat converts a numeric 'value' regarded as having type 'valueType' into a
// big.Rat object.
// Note:
// We use an explicit valueType rather than depend on the type stored in 'value' to force
// the type of MAX(column) or MIN(column) to correspond to the type of 'column'.
// (We've had issues where the type of MAX(column) returned by Vitess was signed even if the
// type of column was unsigned).
//
// It returns an error if value cannot be parsed as valueType, or if valueType
// is not an unsigned, signed or float type.
func valueToBigRat(value sqltypes.Value, valueType querypb.Type) (*big.Rat, error) {
	switch {
	case sqltypes.IsUnsigned(valueType):
		nativeValue, err := value.ParseUint64()
		if err != nil {
			return nil, err
		}
		return uint64ToBigRat(nativeValue), nil
	case sqltypes.IsSigned(valueType):
		nativeValue, err := value.ParseInt64()
		if err != nil {
			return nil, err
		}
		return int64ToBigRat(nativeValue), nil
	case sqltypes.IsFloat(valueType):
		nativeValue, err := value.ParseFloat64()
		if err != nil {
			return nil, err
		}
		return float64ToBigRat(nativeValue), nil
	default:
		// The function already reports failures through its error result, so
		// an unsupported column type is returned to the caller instead of
		// crashing the process.
		return nil, fmt.Errorf("got value with a non numeric type %v: %v", valueType, value)
	}
}
Exemplo n.º 15
0
// AddColumn appends a column to the Table. columnType is a MySQL column type
// string such as "bigint(20) unsigned" or "varbinary(255)". Its base type (the
// text before any "(" or space, compared case-insensitively) sets the column
// category: integer types are CAT_NUMBER, varbinary is CAT_VARBINARY and
// everything else is CAT_OTHER.
//
// When extra is "auto_increment" the column is flagged IsAuto and gets no
// default. Otherwise a non-NULL defval becomes the Default, built as a numeric
// value for CAT_NUMBER columns and as a string value for the rest.
func (ta *Table) AddColumn(name string, columnType string, defval sqltypes.Value, extra string) {
	index := len(ta.Columns)
	ta.Columns = append(ta.Columns, TableColumn{Name: name})
	// Classify on the base type name only: a substring match on "int" would
	// also catch unrelated types such as "point" and "multipoint".
	baseType := strings.ToLower(columnType)
	if i := strings.IndexAny(baseType, "( "); i >= 0 {
		baseType = baseType[:i]
	}
	switch baseType {
	case "tinyint", "smallint", "mediumint", "int", "integer", "bigint":
		ta.Columns[index].Category = CAT_NUMBER
	case "varbinary":
		ta.Columns[index].Category = CAT_VARBINARY
	default:
		ta.Columns[index].Category = CAT_OTHER
	}
	if extra == "auto_increment" {
		ta.Columns[index].IsAuto = true
		// Ignore default value, if any
		return
	}
	if defval.IsNull() {
		return
	}
	if ta.Columns[index].Category == CAT_NUMBER {
		ta.Columns[index].Default = sqltypes.MakeNumeric(defval.Raw())
	} else {
		ta.Columns[index].Default = sqltypes.MakeString(defval.Raw())
	}
}
Exemplo n.º 16
0
// getWhereClause narrows whereClause to the split column range [start, end).
// A null bound leaves that side of the range open; the non-null bounds are
// added to bindVars under startBindVarName / endBindVarName and referenced as
// bind variables. With both bounds null, whereClause (possibly nil) is
// returned as is. An existing clause and the range are each parenthesized and
// ANDed together.
func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars map[string]interface{}, start, end sqltypes.Value) *sqlparser.Where {
	if start.IsNull() && end.IsNull() {
		return whereClause
	}
	pk := &sqlparser.ColName{
		Name: sqlparser.SQLName(qs.splitColumn),
	}

	var lower, upper *sqlparser.ComparisonExpr
	// splitColumn >= :start
	if !start.IsNull() {
		lower = &sqlparser.ComparisonExpr{
			Operator: sqlparser.GreaterEqualStr,
			Left:     pk,
			Right:    sqlparser.ValArg([]byte(":" + startBindVarName)),
		}
		bindVars[startBindVarName] = start.ToNative()
	}
	// splitColumn < :end
	if !end.IsNull() {
		upper = &sqlparser.ComparisonExpr{
			Operator: sqlparser.LessThanStr,
			Left:     pk,
			Right:    sqlparser.ValArg([]byte(":" + endBindVarName)),
		}
		bindVars[endBindVarName] = end.ToNative()
	}

	var rangeExpr sqlparser.BoolExpr
	switch {
	case lower == nil:
		rangeExpr = upper
	case upper == nil:
		rangeExpr = lower
	default:
		rangeExpr = &sqlparser.AndExpr{Left: lower, Right: upper}
	}

	if whereClause == nil {
		return &sqlparser.Where{Type: sqlparser.WhereStr, Expr: rangeExpr}
	}
	return &sqlparser.Where{
		Type: sqlparser.WhereStr,
		Expr: &sqlparser.AndExpr{
			Left:  &sqlparser.ParenBoolExpr{Expr: whereClause.Expr},
			Right: &sqlparser.ParenBoolExpr{Expr: rangeExpr},
		},
	}
}
Exemplo n.º 17
0
// Convert takes a field and a value, and returns the value as:
// - nil for NULL value
// - uint64 for unsigned BIGINT values
// - int64 for all other integer values (signed and unsigned)
// - float64 for floating point values that fit in a float
// - []byte for everything else
//
// Integer text is parsed in base 10. MySQL sends integers as decimal text, and
// ZEROFILL columns can produce leading zeros (e.g. "00010") that base-0
// parsing would treat as an octal prefix, yielding a wrong value or an error.
func Convert(field Field, val sqltypes.Value) (interface{}, error) {
	if val.IsNull() {
		return nil, nil
	}

	switch field.Type {
	case VT_LONGLONG:
		if field.Flags&VT_UNSIGNED_FLAG == VT_UNSIGNED_FLAG {
			return strconv.ParseUint(val.String(), 10, 64)
		}
		return strconv.ParseInt(val.String(), 10, 64)
	case VT_TINY, VT_SHORT, VT_LONG, VT_INT24:
		// Regardless of whether UNSIGNED_FLAG is set in field.Flags, we map all
		// signed and unsigned values to a signed Go type because
		// - Go doesn't officially support uint64 in their SQL interface
		// - there is no loss of the value
		// The only exception we make are for unsigned BIGINTs, see VT_LONGLONG above.
		return strconv.ParseInt(val.String(), 10, 64)
	case VT_FLOAT, VT_DOUBLE:
		return strconv.ParseFloat(val.String(), 64)
	}
	return val.Raw(), nil
}
Exemplo n.º 18
0
// UnmarshalBson bson-decodes into StreamEvent.
//
// buf is positioned at the encoded value and kind is its BSON element kind.
// EOO and Object are decoded as an embedded document; Null returns without
// touching streamEvent; any other kind panics with a BsonError. The keys
// Category, TableName, PrimaryKeyFields, PrimaryKeyValues, Sql, Timestamp and
// TransactionID are decoded; every other key is skipped with bson.Skip.
func (streamEvent *StreamEvent) UnmarshalBson(buf *bytes.Buffer, kind byte) {
	switch kind {
	case bson.EOO, bson.Object:
		// valid
	case bson.Null:
		return
	default:
		panic(bson.NewBsonError("unexpected kind %v for StreamEvent", kind))
	}
	// Skip the document's 4-byte length prefix; the loop below stops at EOO.
	bson.Next(buf, 4)

	// Each element is a kind byte followed by its key and value. The loop
	// variable kind deliberately shadows the parameter with the element's kind.
	for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
		switch bson.ReadCString(buf) {
		case "Category":
			streamEvent.Category = bson.DecodeString(buf, kind)
		case "TableName":
			streamEvent.TableName = bson.DecodeString(buf, kind)
		case "PrimaryKeyFields":
			// []mproto.Field
			// A Null value leaves the field unchanged; a non-array panics.
			if kind != bson.Null {
				if kind != bson.Array {
					panic(bson.NewBsonError("unexpected kind %v for streamEvent.PrimaryKeyFields", kind))
				}
				bson.Next(buf, 4)
				streamEvent.PrimaryKeyFields = make([]mproto.Field, 0, 8)
				for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
					// Array elements are keyed by their index; discard the key.
					bson.SkipIndex(buf)
					var _v1 mproto.Field
					_v1.UnmarshalBson(buf, kind)
					streamEvent.PrimaryKeyFields = append(streamEvent.PrimaryKeyFields, _v1)
				}
			}
		case "PrimaryKeyValues":
			// [][]sqltypes.Value
			// Each element is itself an array of values; a Null element is
			// appended as a nil row.
			if kind != bson.Null {
				if kind != bson.Array {
					panic(bson.NewBsonError("unexpected kind %v for streamEvent.PrimaryKeyValues", kind))
				}
				bson.Next(buf, 4)
				streamEvent.PrimaryKeyValues = make([][]sqltypes.Value, 0, 8)
				for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
					bson.SkipIndex(buf)
					var _v2 []sqltypes.Value
					// []sqltypes.Value
					if kind != bson.Null {
						if kind != bson.Array {
							panic(bson.NewBsonError("unexpected kind %v for _v2", kind))
						}
						bson.Next(buf, 4)
						_v2 = make([]sqltypes.Value, 0, 8)
						for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
							bson.SkipIndex(buf)
							var _v3 sqltypes.Value
							_v3.UnmarshalBson(buf, kind)
							_v2 = append(_v2, _v3)
						}
					}
					streamEvent.PrimaryKeyValues = append(streamEvent.PrimaryKeyValues, _v2)
				}
			}
		case "Sql":
			streamEvent.Sql = bson.DecodeString(buf, kind)
		case "Timestamp":
			streamEvent.Timestamp = bson.DecodeInt64(buf, kind)
		case "TransactionID":
			streamEvent.TransactionID = bson.DecodeString(buf, kind)
		default:
			bson.Skip(buf, kind)
		}
	}
}
Exemplo n.º 19
0
// SetMysqlStats stores the statistics read from MySQL's
// information_schema.tables table: table rows (tr), data length (dl), index
// length (il) and data free (df). Parse errors are ignored; each field gets
// whatever ParseInt64 returns for its value.
func (ta *Table) SetMysqlStats(tr, dl, il, df sqltypes.Value) {
	asInt64 := func(v sqltypes.Value) int64 {
		n, _ := v.ParseInt64() // best-effort: statistics are informational only
		return n
	}
	ta.TableRows, ta.DataLength, ta.IndexLength, ta.DataFree = asInt64(tr), asInt64(dl), asInt64(il), asInt64(df)
}
Exemplo n.º 20
0
// UnmarshalBson bson-decodes into QueryResult.
//
// buf is positioned at the encoded value and kind is its BSON element kind.
// EOO and Object are decoded as an embedded document; Null returns without
// touching queryResult; any other kind panics with a BsonError. The keys
// Fields, RowsAffected, InsertId and Rows are decoded; every other key is
// skipped with bson.Skip.
func (queryResult *QueryResult) UnmarshalBson(buf *bytes.Buffer, kind byte) {
	switch kind {
	case bson.EOO, bson.Object:
		// valid
	case bson.Null:
		return
	default:
		panic(bson.NewBsonError("unexpected kind %v for QueryResult", kind))
	}
	// Skip the document's 4-byte length prefix; the loop below stops at EOO.
	bson.Next(buf, 4)

	// Each element is a kind byte followed by its key and value. The loop
	// variable kind deliberately shadows the parameter with the element's kind.
	for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
		switch bson.ReadCString(buf) {
		case "Fields":
			// []Field
			// A Null value leaves Fields unchanged; a non-array panics.
			if kind != bson.Null {
				if kind != bson.Array {
					panic(bson.NewBsonError("unexpected kind %v for queryResult.Fields", kind))
				}
				bson.Next(buf, 4)
				queryResult.Fields = make([]Field, 0, 8)
				for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
					// Array elements are keyed by their index; discard the key.
					bson.SkipIndex(buf)
					var _v1 Field
					_v1.UnmarshalBson(buf, kind)
					queryResult.Fields = append(queryResult.Fields, _v1)
				}
			}
		case "RowsAffected":
			queryResult.RowsAffected = bson.DecodeUint64(buf, kind)
		case "InsertId":
			queryResult.InsertId = bson.DecodeUint64(buf, kind)
		case "Rows":
			// [][]sqltypes.Value
			// Each row is itself an array of values; a Null row is appended
			// as a nil slice.
			if kind != bson.Null {
				if kind != bson.Array {
					panic(bson.NewBsonError("unexpected kind %v for queryResult.Rows", kind))
				}
				bson.Next(buf, 4)
				queryResult.Rows = make([][]sqltypes.Value, 0, 8)
				for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
					bson.SkipIndex(buf)
					var _v2 []sqltypes.Value
					// []sqltypes.Value
					if kind != bson.Null {
						if kind != bson.Array {
							panic(bson.NewBsonError("unexpected kind %v for _v2", kind))
						}
						bson.Next(buf, 4)
						_v2 = make([]sqltypes.Value, 0, 8)
						for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
							bson.SkipIndex(buf)
							var _v3 sqltypes.Value
							_v3.UnmarshalBson(buf, kind)
							_v2 = append(_v2, _v3)
						}
					}
					queryResult.Rows = append(queryResult.Rows, _v2)
				}
			}
		default:
			bson.Skip(buf, kind)
		}
	}
}
Exemplo n.º 21
0
// UnmarshalBson bson-decodes into QueryResult.
//
// buf is positioned at the encoded value and kind is its BSON element kind.
// EOO and Object are decoded as an embedded document; Null returns without
// touching queryResult; any other kind panics with a BsonError. The keys
// Fields, RowsAffected, InsertId, Rows and Err are decoded; every other key is
// skipped with bson.Skip. Fields are decoded through the BSONField wire form
// and converted to *querypb.Field with sqltypes.MySQLToType.
func (queryResult *QueryResult) UnmarshalBson(buf *bytes.Buffer, kind byte) {
	switch kind {
	case bson.EOO, bson.Object:
		// valid
	case bson.Null:
		return
	default:
		panic(bson.NewBsonError("unexpected kind %v for QueryResult", kind))
	}
	// Skip the document's 4-byte length prefix; the loop below stops at EOO.
	bson.Next(buf, 4)

	// Each element is a kind byte followed by its key and value. The loop
	// variable kind deliberately shadows the parameter with the element's kind.
	for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
		switch bson.ReadCString(buf) {
		case "Fields":
			// []*query.Field
			// A Null value leaves Fields unchanged; a non-array panics.
			if kind != bson.Null {
				if kind != bson.Array {
					panic(bson.NewBsonError("unexpected kind %v for queryResult.Fields", kind))
				}
				bson.Next(buf, 4)
				queryResult.Fields = make([]*querypb.Field, 0, 8)
				// NOTE(review): f is reused for every element and never reset,
				// so a key missing from a later element would keep the previous
				// element's value unless bson.UnmarshalFromBuffer clears its
				// target — confirm. The element kind is also not checked here.
				var f BSONField
				for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
					// Array elements are keyed by their index; discard the key.
					bson.SkipIndex(buf)
					var _v1 *querypb.Field
					// *query.Field
					_v1 = new(querypb.Field)
					bson.UnmarshalFromBuffer(buf, &f)
					_v1.Name = f.Name
					// Combine the MySQL type and flags into a single querypb type.
					_v1.Type = sqltypes.MySQLToType(f.Type, f.Flags)
					queryResult.Fields = append(queryResult.Fields, _v1)
				}
			}
		case "RowsAffected":
			queryResult.RowsAffected = bson.DecodeUint64(buf, kind)
		case "InsertId":
			queryResult.InsertId = bson.DecodeUint64(buf, kind)
		case "Rows":
			// [][]sqltypes.Value
			// Each row is itself an array of values; a Null row is appended
			// as a nil slice.
			if kind != bson.Null {
				if kind != bson.Array {
					panic(bson.NewBsonError("unexpected kind %v for queryResult.Rows", kind))
				}
				bson.Next(buf, 4)
				queryResult.Rows = make([][]sqltypes.Value, 0, 8)
				for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
					bson.SkipIndex(buf)
					var _v2 []sqltypes.Value
					// []sqltypes.Value
					if kind != bson.Null {
						if kind != bson.Array {
							panic(bson.NewBsonError("unexpected kind %v for _v2", kind))
						}
						bson.Next(buf, 4)
						_v2 = make([]sqltypes.Value, 0, 8)
						for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
							bson.SkipIndex(buf)
							var _v3 sqltypes.Value
							_v3.UnmarshalBson(buf, kind)
							_v2 = append(_v2, _v3)
						}
					}
					queryResult.Rows = append(queryResult.Rows, _v2)
				}
			}
		case "Err":
			// *RPCError
			// A non-Null value allocates a fresh RPCError and decodes into it.
			if kind != bson.Null {
				queryResult.Err = new(RPCError)
				(*queryResult.Err).UnmarshalBson(buf, kind)
			}
		default:
			bson.Skip(buf, kind)
		}
	}
}
Exemplo n.º 22
0
// getWhereClause narrows whereClause to the split column range [start, end).
// A null bound leaves that side of the range open; each non-null bound is
// referenced as a bind variable (startBindVarName / endBindVarName) whose value
// is stored in bindVars. With both bounds null, whereClause (possibly nil) is
// returned as is. An existing clause and the range are each parenthesized and
// ANDed together.
func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars map[string]interface{}, start, end sqltypes.Value) *sqlparser.Where {
	if start.IsNull() && end.IsNull() {
		return whereClause
	}
	pk := &sqlparser.ColName{
		Name: sqlparser.SQLName(qs.splitColumn),
	}

	// bindBound stores bound under name as an int64 (numeric), raw bytes
	// (string) or float64 (fractional); other kinds are not stored.
	// NOTE(review): parse errors are ignored, so e.g. an unsigned BIGINT above
	// math.MaxInt64 is stored as whatever ParseInt64 returns — confirm intended.
	bindBound := func(name string, bound sqltypes.Value) {
		switch {
		case bound.IsNumeric():
			n, _ := bound.ParseInt64()
			bindVars[name] = n
		case bound.IsString():
			bindVars[name] = bound.Raw()
		case bound.IsFractional():
			f, _ := bound.ParseFloat64()
			bindVars[name] = f
		}
	}

	var lower, upper *sqlparser.ComparisonExpr
	// splitColumn >= :start
	if !start.IsNull() {
		lower = &sqlparser.ComparisonExpr{
			Operator: sqlparser.AST_GE,
			Left:     pk,
			Right:    sqlparser.ValArg([]byte(":" + startBindVarName)),
		}
		bindBound(startBindVarName, start)
	}
	// splitColumn < :end
	if !end.IsNull() {
		upper = &sqlparser.ComparisonExpr{
			Operator: sqlparser.AST_LT,
			Left:     pk,
			Right:    sqlparser.ValArg([]byte(":" + endBindVarName)),
		}
		bindBound(endBindVarName, end)
	}

	var rangeExpr sqlparser.BoolExpr
	switch {
	case lower == nil:
		rangeExpr = upper
	case upper == nil:
		rangeExpr = lower
	default:
		rangeExpr = &sqlparser.AndExpr{Left: lower, Right: upper}
	}

	if whereClause == nil {
		return &sqlparser.Where{Type: sqlparser.AST_WHERE, Expr: rangeExpr}
	}
	return &sqlparser.Where{
		Type: sqlparser.AST_WHERE,
		Expr: &sqlparser.AndExpr{
			Left:  &sqlparser.ParenBoolExpr{Expr: whereClause.Expr},
			Right: &sqlparser.ParenBoolExpr{Expr: rangeExpr},
		},
	}
}