示例#1
0
// parseInt computes split boundaries for an integral primary key range.
//
// pkMinMax must contain at least one row whose first two columns hold the
// minimum and maximum key values. The range is divided by qs.splitCount and
// the splitCount-1 interior boundary values are returned; qs.rowCount is set
// to the width of one interval.
//
// If the minimum starts with '-', both values are parsed as int64, otherwise
// as uint64. When the computed interval width is zero no split is possible
// and (nil, nil) is returned. Parse and value-building failures are returned
// as the error.
func (qs *QuerySplitter) parseInt(pkMinMax *mproto.QueryResult) ([]sqltypes.Value, error) {
	boundaries := []sqltypes.Value{}
	minRaw := pkMinMax.Rows[0][0].Raw()
	minNumeric := sqltypes.MakeNumeric(minRaw)
	maxNumeric := sqltypes.MakeNumeric(pkMinMax.Rows[0][1].Raw())
	// The length guard keeps an empty raw value from panicking on the index;
	// such a value falls through to the unsigned branch, where ParseUint64
	// decides how to treat it (presumably as a parse error — TODO confirm).
	if len(minRaw) > 0 && minRaw[0] == '-' {
		// signed values, use int64
		min, err := minNumeric.ParseInt64()
		if err != nil {
			return nil, err
		}
		max, err := maxNumeric.ParseInt64()
		if err != nil {
			return nil, err
		}
		interval := (max - min) / int64(qs.splitCount)
		if interval == 0 {
			// Range is narrower than splitCount: nothing to split, and this
			// is not an error.
			return nil, nil
		}
		qs.rowCount = interval
		for i := int64(1); i < int64(qs.splitCount); i++ {
			v, err := sqltypes.BuildValue(min + interval*i)
			if err != nil {
				return nil, err
			}
			boundaries = append(boundaries, v)
		}
		return boundaries, nil
	}
	// unsigned values, use uint64
	min, err := minNumeric.ParseUint64()
	if err != nil {
		return nil, err
	}
	max, err := maxNumeric.ParseUint64()
	if err != nil {
		return nil, err
	}
	interval := (max - min) / uint64(qs.splitCount)
	if interval == 0 {
		// Range is narrower than splitCount: nothing to split, and this is
		// not an error.
		return nil, nil
	}
	qs.rowCount = int64(interval)
	for i := uint64(1); i < uint64(qs.splitCount); i++ {
		v, err := sqltypes.BuildValue(min + interval*i)
		if err != nil {
			return nil, err
		}
		boundaries = append(boundaries, v)
	}
	return boundaries, nil
}
示例#2
0
// parseFloat computes split boundaries for a floating point primary key
// range.
//
// The first two columns of the first row of pkMinMax are parsed as float64
// minimum and maximum. The range is divided by qs.splitCount and the
// splitCount-1 interior boundaries are returned; qs.rowCount is set to the
// interval width truncated to int64. A zero interval yields (nil, nil).
func (qs *QuerySplitter) parseFloat(pkMinMax *mproto.QueryResult) ([]sqltypes.Value, error) {
	bounds := pkMinMax.Rows[0]
	lo, err := strconv.ParseFloat(bounds[0].String(), 64)
	if err != nil {
		return nil, err
	}
	hi, err := strconv.ParseFloat(bounds[1].String(), 64)
	if err != nil {
		return nil, err
	}

	step := (hi - lo) / float64(qs.splitCount)
	if step == 0 {
		// Both parses succeeded, so there is no error to report here.
		return nil, nil
	}
	qs.rowCount = int64(step)

	result := []sqltypes.Value{}
	for n := 1; n < qs.splitCount; n++ {
		value, err := sqltypes.BuildValue(lo + step*float64(n))
		if err != nil {
			return nil, err
		}
		result = append(result, value)
	}
	return result, nil
}
示例#3
0
// handleDMLEvent deletes the row cache entries for every primary key tuple
// carried by a DML stream event.
//
// It panics with a TabletError if the event's table is not present in the
// schema info, and does nothing when the table's cache type is CACHE_NONE.
// If a key value cannot be built, the error is logged, the "Invalidation"
// internal error counter is bumped and processing of the event stops.
// The number of deleted keys is added to the table's invalidations counter,
// including on the early error exit.
func (rci *RowcacheInvalidator) handleDMLEvent(event *blproto.StreamEvent) {
	tableInfo := rci.qe.schemaInfo.GetTable(event.TableName)
	if tableInfo == nil {
		panic(NewTabletError(FAIL, "Table %s not found", event.TableName))
	}
	if tableInfo.CacheType == schema.CACHE_NONE {
		return
	}

	invalidations := int64(0)
	// Scratch slice reused across tuples to avoid reallocating per row.
	sqlTypeKeys := make([]sqltypes.Value, 0, len(event.PKColNames))
	for _, pkTuple := range event.PKValues {
		sqlTypeKeys = sqlTypeKeys[:0]
		for _, pkVal := range pkTuple {
			key, err := sqltypes.BuildValue(pkVal)
			if err != nil {
				log.Errorf("Error building invalidation key for %#v: '%v'", event, err)
				internalErrors.Add("Invalidation", 1)
				// Entries deleted for earlier tuples are already gone from
				// the cache; record them so the stat does not under-count.
				tableInfo.invalidations.Add(invalidations)
				return
			}
			sqlTypeKeys = append(sqlTypeKeys, key)
		}
		newKey := validateKey(tableInfo, buildKey(sqlTypeKeys))
		if newKey == "" {
			// validateKey rejected the key; nothing to delete for this tuple.
			continue
		}
		tableInfo.Cache.Delete(newKey)
		invalidations++
	}
	tableInfo.invalidations.Add(invalidations)
}
示例#4
0
// TestSplitQuery splits a query over test_table into three parts for a
// primary key range of [0, 300] and checks both the generated SQL of every
// split and that each split reports a RowCount of 100.
func TestSplitQuery(t *testing.T) {
	schemaInfo := getSchemaInfo()
	query := &proto.BoundQuery{
		Sql: "select * from test_table where count > :count",
	}
	splitter := NewQuerySplitter(query, 3, schemaInfo)
	splitter.validateQuery()
	min, _ := sqltypes.BuildValue(0)
	max, _ := sqltypes.BuildValue(300)
	minField := mproto.Field{
		Name: "min",
		Type: mproto.VT_LONGLONG,
	}
	// The second column carries the maximum; it was previously mislabelled
	// "min" by copy-paste.
	maxField := mproto.Field{
		Name: "max",
		Type: mproto.VT_LONGLONG,
	}
	fields := []mproto.Field{minField, maxField}
	row := []sqltypes.Value{min, max}
	rows := [][]sqltypes.Value{row}
	pkMinMax := &mproto.QueryResult{
		Rows:   rows,
		Fields: fields,
	}
	splits := splitter.split(pkMinMax)
	got := []string{}
	for _, split := range splits {
		if split.RowCount != 100 {
			t.Errorf("wrong RowCount, got: %v, want: %v", split.RowCount, 100)
		}
		got = append(got, split.Query.Sql)
	}
	want := []string{
		"select * from test_table where count > :count and id < 100",
		"select * from test_table where count > :count and id >= 100 and id < 200",
		"select * from test_table where count > :count and id >= 200",
	}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("wrong splits, got: %v, want: %v", got, want)
	}
}
示例#5
0
// TestBuildStreamComment checks the stream comment generated for an update
// that changes the second primary key column:
// set pk2 = 'xyz' where pk1 = 1 and pk2 = 'abc'.
func TestBuildStreamComment(t *testing.T) {
	pk1 := "pk1"
	pk2 := "pk2"
	tableInfo := createTableInfo("Table",
		map[string]string{pk1: "int", pk2: "varchar(128)", "col1": "int"},
		[]string{pk1, pk2})

	// set pk2 = 'xyz' where pk1=1 and pk2 = 'abc'
	bindVars := map[string]interface{}{}
	// BuildValue errors are ignored for these plain literals; a failure
	// would surface in the comparison at the end of the test.
	pk1Val, _ := sqltypes.BuildValue(1)
	pk2Val, _ := sqltypes.BuildValue("abc")
	pkValues := []interface{}{pk1Val, pk2Val}
	pkList, err := buildValueList(&tableInfo, pkValues, bindVars)
	if err != nil {
		t.Fatalf("buildValueList failed: %v", err)
	}
	pk2SecVal, _ := sqltypes.BuildValue("xyz")
	// nil means the first PK column is not modified by the update.
	secondaryPKValues := []interface{}{nil, pk2SecVal}
	secondaryList, _ := buildSecondaryList(&tableInfo, pkList, secondaryPKValues, bindVars)
	want := []byte(" /* _stream Table (pk1 pk2 ) (1 'YWJj' ) (1 'eHl6' ); */")
	got := buildStreamComment(&tableInfo, pkList, secondaryList)
	if !reflect.DeepEqual(got, want) {
		// %q renders the byte slices as readable text rather than a list of
		// numeric byte values.
		t.Errorf("case 1 failed, got %q, want %q", got, want)
	}
}
示例#6
0
// TestBuildSecondaryList checks the secondary PK list built for an update
// that rewrites only the second primary key column:
// set pk2 = 'xyz' where pk1 = 1 and pk2 = 'abc'.
func TestBuildSecondaryList(t *testing.T) {
	const (
		firstPK  = "pk1"
		secondPK = "pk2"
	)
	columns := map[string]string{firstPK: "int", secondPK: "varchar(128)", "col1": "int"}
	tableInfo := createTableInfo("Table", columns, []string{firstPK, secondPK})

	// Values of the where clause: pk1 = 1 and pk2 = 'abc'.
	one, _ := sqltypes.BuildValue(1)
	abc, _ := sqltypes.BuildValue("abc")
	// New value assigned to pk2 by the update.
	xyz, _ := sqltypes.BuildValue("xyz")

	bindVars := map[string]interface{}{}
	pkList, _ := buildValueList(&tableInfo, []interface{}{one, abc}, bindVars)

	// nil leaves pk1 unchanged, so the expected secondary list is [[1 xyz]].
	got, _ := buildSecondaryList(&tableInfo, pkList, []interface{}{nil, xyz}, bindVars)
	want := [][]sqltypes.Value{{one, xyz}}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("case 1 failed, got %v, want %v", got, want)
	}
}
示例#7
0
// EncodeValue writes the SQL text for value into buf.
//
// nil is written as "null". A []sqltypes.Value is written as a
// comma-separated list, a [][]sqltypes.Value as a comma-separated list of
// parenthesized rows, and a []interface{} as a single parenthesized
// comma-separated list. A TupleEqualityList encodes itself. Any other value
// is converted with sqltypes.BuildValue and written with EncodeSql. The first
// error encountered is returned; buf may already hold partial output then.
func EncodeValue(buf *bytes.Buffer, value interface{}) error {
	switch typed := value.(type) {
	case nil:
		buf.WriteString("null")
	case []sqltypes.Value:
		for idx, elem := range typed {
			if idx > 0 {
				buf.WriteString(", ")
			}
			if err := EncodeValue(buf, elem); err != nil {
				return err
			}
		}
	case [][]sqltypes.Value:
		for idx, row := range typed {
			if idx > 0 {
				buf.WriteString(", ")
			}
			buf.WriteByte('(')
			if err := EncodeValue(buf, row); err != nil {
				return err
			}
			buf.WriteByte(')')
		}
	case []interface{}:
		buf.WriteByte('(')
		for idx, elem := range typed {
			if idx > 0 {
				buf.WriteString(", ")
			}
			if err := EncodeValue(buf, elem); err != nil {
				return err
			}
		}
		buf.WriteByte(')')
	case TupleEqualityList:
		// Keep the explicit nil check so success always returns a literal nil.
		if err := typed.Encode(buf); err != nil {
			return err
		}
	default:
		built, err := sqltypes.BuildValue(typed)
		if err != nil {
			return err
		}
		built.EncodeSql(buf)
	}
	return nil
}
示例#8
0
func TestGetSplitBoundaries(t *testing.T) {
	min, _ := sqltypes.BuildValue(10)
	max, _ := sqltypes.BuildValue(60)
	row := []sqltypes.Value{min, max}
	rows := [][]sqltypes.Value{row}

	minField := mproto.Field{Name: "min", Type: mproto.VT_LONGLONG}
	maxField := mproto.Field{Name: "max", Type: mproto.VT_LONGLONG}
	fields := []mproto.Field{minField, maxField}

	pkMinMax := &mproto.QueryResult{
		Fields: fields,
		Rows:   rows,
	}

	splitter := &QuerySplitter{}
	splitter.splitCount = 5
	boundaries := splitter.getSplitBoundaries(pkMinMax)
	if len(boundaries) != splitter.splitCount-1 {
		t.Errorf("wrong number of boundaries got: %v, want: %v", len(boundaries), splitter.splitCount-1)
	}
	got := splitter.getSplitBoundaries(pkMinMax)
	want := []sqltypes.Value{buildVal(20), buildVal(30), buildVal(40), buildVal(50)}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect boundaries, got: %v, want: %v", got, want)
	}

	// Test negative min value
	min, _ = sqltypes.BuildValue(-100)
	max, _ = sqltypes.BuildValue(100)
	row = []sqltypes.Value{min, max}
	rows = [][]sqltypes.Value{row}
	pkMinMax.Rows = rows
	got = splitter.getSplitBoundaries(pkMinMax)
	want = []sqltypes.Value{buildVal(-60), buildVal(-20), buildVal(20), buildVal(60)}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect boundaries, got: %v, want: %v", got, want)
	}

	// Test float min max
	min, _ = sqltypes.BuildValue(10.5)
	max, _ = sqltypes.BuildValue(60.5)
	row = []sqltypes.Value{min, max}
	rows = [][]sqltypes.Value{row}
	minField = mproto.Field{Name: "min", Type: mproto.VT_DOUBLE}
	maxField = mproto.Field{Name: "max", Type: mproto.VT_DOUBLE}
	fields = []mproto.Field{minField, maxField}
	pkMinMax.Rows = rows
	pkMinMax.Fields = fields
	got = splitter.getSplitBoundaries(pkMinMax)
	want = []sqltypes.Value{buildVal(20.5), buildVal(30.5), buildVal(40.5), buildVal(50.5)}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect boundaries, got: %v, want: %v", got, want)
	}
}
示例#9
0
// resolveListArg resolves the list bind variable named by key into a slice of
// sqltypes.Value, validating every element against col.
//
// It returns a TabletError when the bind variable cannot be fetched, when it
// is not a []interface{}, or when an element cannot be converted to a value.
// An error from validateValue is returned unchanged.
func resolveListArg(col *schema.TableColumn, key string, bindVars map[string]interface{}) ([]sqltypes.Value, error) {
	val, _, err := sqlparser.FetchBindVar(key, bindVars)
	if err != nil {
		return nil, NewTabletError(FAIL, "%v", err)
	}
	// Checked assertion: a bind variable of any other type is reported as an
	// error rather than crashing the caller with a panic.
	list, ok := val.([]interface{})
	if !ok {
		return nil, NewTabletError(FAIL, "expecting list for bind variable %s, got %T", key, val)
	}
	resolved := make([]sqltypes.Value, len(list))
	for i, v := range list {
		sqlval, err := sqltypes.BuildValue(v)
		if err != nil {
			return nil, NewTabletError(FAIL, "%v", err)
		}
		if err = validateValue(col, sqlval); err != nil {
			return nil, err
		}
		resolved[i] = sqlval
	}
	return resolved, nil
}
示例#10
0
// getSchemaInfo builds a SchemaInfo fixture holding two tables:
// "test_table" with int columns id and count and id as its primary key, and
// "test_table_no_pk" with a single int column id and no primary key columns.
func getSchemaInfo() *SchemaInfo {
	// Default value shared by every column of the fixture.
	defaultVal, _ := sqltypes.BuildValue(0)

	withPK := &schema.Table{Name: "test_table"}
	withPK.AddColumn("id", "int", defaultVal, "")
	withPK.AddColumn("count", "int", defaultVal, "")
	withPK.PKColumns = []int{0}

	withoutPK := &schema.Table{Name: "test_table_no_pk"}
	withoutPK.AddColumn("id", "int", defaultVal, "")
	withoutPK.PKColumns = []int{}

	return &SchemaInfo{
		tables: map[string]*TableInfo{
			"test_table":       &TableInfo{Table: withPK},
			"test_table_no_pk": &TableInfo{Table: withoutPK},
		},
	}
}
示例#11
0
// resolveValue turns value into a sqltypes.Value and validates it against
// col.
//
// A sqltypes.Value is used as is. A string is treated as a bind variable
// reference: it is looked up in bindVars and the fetched value is converted
// with sqltypes.BuildValue; a failure of either step is returned as a
// TabletError. Any other type of value causes a panic. The error from
// validateValue, if any, is returned together with the resolved value.
func resolveValue(col *schema.TableColumn, value interface{}, bindVars map[string]interface{}) (result sqltypes.Value, err error) {
	switch typed := value.(type) {
	case sqltypes.Value:
		result = typed
	case string:
		bound, _, fetchErr := sqlparser.FetchBindVar(typed, bindVars)
		if fetchErr != nil {
			return result, NewTabletError(FAIL, "%v", fetchErr)
		}
		built, buildErr := sqltypes.BuildValue(bound)
		if buildErr != nil {
			return result, NewTabletError(FAIL, "%v", buildErr)
		}
		result = built
	default:
		panic(fmt.Sprintf("incompatible value type %v", typed))
	}

	return result, validateValue(col, result)
}
示例#12
0
// TestBuildValuesList exercises buildValueList against a table with a
// composite primary key (pk1 int, pk2 varbinary). It covers literal values,
// bind variables, null and invalid bind values, type mismatches, multi-row
// inserts, IN clauses and list bind variables, including the error cases.
func TestBuildValuesList(t *testing.T) {
	tableInfo := createTableInfo("Table",
		map[string]string{"pk1": "int", "pk2": "varbinary(128)", "col1": "int"},
		[]string{"pk1", "pk2"})

	// simple PK clause. e.g. where pk1 = 1
	bindVars := map[string]interface{}{}
	pk1Val, _ := sqltypes.BuildValue(1)
	pkValues := []interface{}{pk1Val}
	// want [[1]]
	want := [][]sqltypes.Value{[]sqltypes.Value{pk1Val}}
	got, _ := buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// simple PK clause with bindVars. e.g. where pk1 = :pk1
	bindVars["pk1"] = 1
	pkValues = []interface{}{":pk1"}
	// want [[1]]
	want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val}}
	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// null value
	bindVars["pk1"] = nil
	pkValues = []interface{}{":pk1"}
	// want a single row holding the zero (null) Value
	want = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.Value{}}}
	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// invalid value
	bindVars["pk1"] = struct{}{}
	pkValues = []interface{}{":pk1"}
	wantErr := "error: unsupported bind variable type struct {}: {}"

	got, err := buildValueList(&tableInfo, pkValues, bindVars)
	if err == nil || err.Error() != wantErr {
		t.Errorf("got %v, want %v", err, wantErr)
	}

	// type mismatch int
	bindVars["pk1"] = "str"
	pkValues = []interface{}{":pk1"}
	wantErr = "error: type mismatch, expecting numeric type for str"

	got, err = buildValueList(&tableInfo, pkValues, bindVars)
	if err == nil || err.Error() != wantErr {
		t.Errorf("got %v, want %v", err, wantErr)
	}

	// type mismatch binary
	bindVars["pk1"] = 1
	bindVars["pk2"] = 1
	pkValues = []interface{}{":pk1", ":pk2"}
	wantErr = "error: type mismatch, expecting string type for 1"

	got, err = buildValueList(&tableInfo, pkValues, bindVars)
	t.Logf("%v", got)
	if err == nil || err.Error() != wantErr {
		t.Errorf("got %v, want %v", err, wantErr)
	}

	// composite PK clause. e.g. where pk1 = 1 and pk2 = "abc"
	pk2Val, _ := sqltypes.BuildValue("abc")
	pkValues = []interface{}{pk1Val, pk2Val}
	// want [[1 abc]]
	want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val, pk2Val}}
	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// multi row composite PK insert
	// e.g. insert into Table(pk1,pk2) values (1, "abc"), (2, "xyz")
	pk1Val2, _ := sqltypes.BuildValue(2)
	pk2Val2, _ := sqltypes.BuildValue("xyz")
	pkValues = []interface{}{
		[]interface{}{pk1Val, pk1Val2},
		[]interface{}{pk2Val, pk2Val2},
	}
	// want [[1 abc][2 xyz]]
	want = [][]sqltypes.Value{
		[]sqltypes.Value{pk1Val, pk2Val},
		[]sqltypes.Value{pk1Val2, pk2Val2}}
	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// composite PK IN clause
	// e.g. where pk1 = 1 and pk2 IN ("abc", "xyz")
	pkValues = []interface{}{
		pk1Val,
		[]interface{}{pk2Val, pk2Val2},
	}
	// want [[1 abc][1 xyz]]
	want = [][]sqltypes.Value{
		[]sqltypes.Value{pk1Val, pk2Val},
		[]sqltypes.Value{pk1Val, pk2Val2},
	}

	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// list arg
	// e.g. where pk1 = 1 and pk2 IN ::list
	bindVars = map[string]interface{}{
		"list": []interface{}{
			"abc",
			"xyz",
		},
	}
	pkValues = []interface{}{
		pk1Val,
		"::list",
	}
	// want [[1 abc][1 xyz]]
	want = [][]sqltypes.Value{
		[]sqltypes.Value{pk1Val, pk2Val},
		[]sqltypes.Value{pk1Val, pk2Val2},
	}

	// This case previously prepared its inputs but never ran them; the call
	// and check below make it an actual test.
	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// list arg one value
	// e.g. where pk1 = 1 and pk2 IN ::list
	bindVars = map[string]interface{}{
		"list": []interface{}{
			"abc",
		},
	}
	pkValues = []interface{}{
		pk1Val,
		"::list",
	}
	// want [[1 abc]]
	want = [][]sqltypes.Value{
		[]sqltypes.Value{pk1Val, pk2Val},
	}

	got, _ = buildValueList(&tableInfo, pkValues, bindVars)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %v, want %v", got, want)
	}

	// list arg empty list
	bindVars = map[string]interface{}{
		"list": []interface{}{},
	}
	pkValues = []interface{}{
		pk1Val,
		"::list",
	}
	wantErr = "error: empty list supplied for list"

	got, err = buildValueList(&tableInfo, pkValues, bindVars)
	if err == nil || err.Error() != wantErr {
		t.Errorf("got %v, want %v", err, wantErr)
	}

	// list arg for non-list
	bindVars = map[string]interface{}{
		"list": []interface{}{},
	}
	pkValues = []interface{}{
		pk1Val,
		":list",
	}
	wantErr = "error: unexpected arg type []interface {} for key list"

	got, err = buildValueList(&tableInfo, pkValues, bindVars)
	if err == nil || err.Error() != wantErr {
		t.Errorf("got %v, want %v", err, wantErr)
	}
}
示例#13
0
func TestGetWhereClause(t *testing.T) {
	splitter := &QuerySplitter{}
	sql := "select * from test_table where count > :count"
	statement, _ := sqlparser.Parse(sql)
	splitter.sel, _ = statement.(*sqlparser.Select)
	splitter.pkCol = "id"

	// no boundary case, start = end = nil, should not change the where clause
	nilValue := sqltypes.Value{}
	clause := splitter.getWhereClause(nilValue, nilValue)
	want := " where count > :count"
	got := sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
	}

	// Set lower bound, should add the lower bound condition to where clause
	start, _ := sqltypes.BuildValue(20)
	clause = splitter.getWhereClause(start, nilValue)
	want = " where count > :count and id >= 20"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Set upper bound, should add the upper bound condition to where clause
	end, _ := sqltypes.BuildValue(40)
	clause = splitter.getWhereClause(nilValue, end)
	want = " where count > :count and id < 40"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Set both bounds, should add two conditions to where clause
	clause = splitter.getWhereClause(start, end)
	want = " where count > :count and id >= 20 and id < 40"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}

	// Original query with no where clause
	sql = "select * from test_table"
	statement, _ = sqlparser.Parse(sql)
	splitter.sel, _ = statement.(*sqlparser.Select)

	// no boundary case, start = end = nil should return no where clause
	clause = splitter.getWhereClause(nilValue, nilValue)
	want = ""
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
	}

	// Set both bounds, should add two conditions to where clause
	clause = splitter.getWhereClause(start, end)
	want = " where id >= 20 and id < 40"
	got = sqlparser.String(clause)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
	}
}
示例#14
0
// buildVal converts val with sqltypes.BuildValue and returns the result,
// discarding the error so that it can be used inline in expressions.
func buildVal(val interface{}) sqltypes.Value {
	built, _ := sqltypes.BuildValue(val)
	return built
}