// makeSpan builds a span over the primary index of tableDesc: Key is the
// primary index key generated for i and EndKey is the one generated for j.
//
// It panics if either key cannot be generated.
func makeSpan(tableDesc *sqlbase.TableDescriptor, i, j int) roachpb.Span {
	var bounds [2]roachpb.Key
	// The start key is generated first, so a failure on i is reported before j
	// is even looked at.
	for idx, val := range [2]int{i, j} {
		key, err := sqlbase.MakePrimaryIndexKey(tableDesc, val)
		if err != nil {
			panic(err)
		}
		bounds[idx] = key
	}
	return roachpb.Span{Key: bounds[0], EndKey: bounds[1]}
}
// splitRangeAtVal splits the range for a table with schema
// `CREATE TABLE test (k INT PRIMARY KEY)` at the row whose primary key is pk;
// that row becomes the first one on the right of the split.
//
// On success it returns the left and right range descriptors reported by the
// split. It returns an error (and zero-valued descriptors) if
// tableDesc.Indexes is not empty, if the primary index key for pk cannot be
// generated, or if the split itself fails.
func splitRangeAtVal(
	ts *server.TestServer, tableDesc *sqlbase.TableDescriptor, pk int,
) (roachpb.RangeDescriptor, roachpb.RangeDescriptor, error) {
	// Zero-valued descriptor returned on every error path.
	var none roachpb.RangeDescriptor

	if len(tableDesc.Indexes) > 0 {
		return none, none, errors.Errorf("expected table with just a PK, got: %+v", tableDesc)
	}

	pik, err := sqlbase.MakePrimaryIndexKey(tableDesc, pk)
	if err != nil {
		return none, none, err
	}

	// The split point is the row sentinel key derived from the primary index
	// key.
	left, right, err := ts.SplitRange(keys.MakeRowSentinelKey(pik))
	if err != nil {
		return none, none, errors.Wrapf(err, "failed to split at row: %d", pk)
	}
	return left, right, nil
}
예제 #3
0
// TestDistSQLJoinAndAgg checks the results of a join and of a few aggregation
// queries that are run with DIST_SQL = ALWAYS against a multi-node test
// cluster, after one of the tables has been split and its ranges moved onto
// the other nodes.
func TestDistSQLJoinAndAgg(t *testing.T) {
	defer leaktest.AfterTest(t)()

	// The distributed join involves two tables:
	//  - NumToSquare, with n rows, mapping each integer from 1 to n to its
	//    square;
	//  - NumToStr, with n^2 rows, mapping integers to their string
	//    representations. This is the table that is split and distributed to
	//    all the nodes.
	const (
		n        = 100
		numNodes = 5
	)

	clusterArgs := base.TestClusterArgs{
		ReplicationMode: base.ReplicationManual,
		ServerArgs: base.TestServerArgs{
			UseDatabase: "test",
		},
	}
	tc := serverutils.StartTestCluster(t, numNodes, clusterArgs)
	defer tc.Stopper().Stop()
	cdb := tc.Server(0).KVClient().(*client.DB)

	squareOf := func(row int) parser.Datum {
		return parser.NewDInt(parser.DInt(row * row))
	}
	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToSquare", "x INT PRIMARY KEY, xsquared INT",
		n, sqlutils.ToRowFn(sqlutils.RowIdxFn, squareOf),
	)

	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToStr", "y INT PRIMARY KEY, str STRING",
		n*n, sqlutils.ToRowFn(sqlutils.RowIdxFn, sqlutils.RowEnglishFn),
	)

	// Break the table up into several ranges, each of which ends up with a
	// single replica on a particular node. This forces the query to be
	// distributed.
	//
	// TODO(radu): this approach should be generalized into test infrastructure
	// (perhaps by adding functionality to logic tests).
	// TODO(radu): we should verify that the plan is indeed distributed as
	// intended.
	numToStrDesc := sqlbase.GetTableDescriptor(cdb, "test", "NumToStr")

	// moveRightOf introduces a split at the row with primary key val and moves
	// the range on the right of the split to targetNode.
	moveRightOf := func(val int, targetNode int) {
		pik, err := sqlbase.MakePrimaryIndexKey(numToStrDesc, val)
		if err != nil {
			t.Fatal(err)
		}

		_, rhs, err := tc.Server(0).SplitRange(keys.MakeRowSentinelKey(pik))
		if err != nil {
			t.Fatal(err)
		}
		// From here on the right-hand range is addressed by its own start key.
		rhsStart := rhs.StartKey.AsRawKey()
		if rhs, err = tc.AddReplicas(rhsStart, tc.Target(targetNode)); err != nil {
			t.Fatal(err)
		}

		// Transfer the lease before removing the first replica; otherwise the
		// removal would have to wait for the lease to expire.
		if err := tc.TransferRangeLease(rhs, tc.Target(targetNode)); err != nil {
			t.Fatal(err)
		}
		if _, err := tc.RemoveReplicas(rhsStart, tc.Target(0)); err != nil {
			t.Fatal(err)
		}
	}
	// Each call moves the right-hand range, so the splits are performed from
	// the back of the table towards the front.
	const rowsPerNode = n * n / numNodes
	for node := numNodes - 1; node >= 1; node-- {
		moveRightOf(rowsPerNode*node, node)
	}

	r := sqlutils.MakeSQLRunner(t, tc.ServerConn(0))
	r.DB.SetMaxOpenConns(1)
	r.Exec("SET DIST_SQL = ALWAYS")
	res := r.QueryStr("SELECT x, str FROM NumToSquare JOIN NumToStr ON y = xsquared")
	// res must hold one entry per integer, paired with the string
	// representation of its square, e.g.:
	//  [1, one]
	//  [2, four]
	//  [3, nine]
	//  [4, one-six]
	// (not necessarily in this order).
	if got := len(res); got != n {
		t.Fatalf("expected %d rows, got %d", n, got)
	}
	strByX := make(map[int]string)
	for _, row := range res {
		if len(row) != 2 {
			t.Fatalf("invalid row %v", row)
		}
		x, err := strconv.Atoi(row[0])
		if err != nil {
			t.Fatalf("error parsing row %v: %s", row, err)
		}
		strByX[x] = row[1]
	}
	for x := 1; x <= n; x++ {
		if got := strByX[x]; got != sqlutils.IntToEnglish(x*x) {
			t.Errorf("invalid string for %d: %s", x, got)
		}
	}

	// isSingleValue reports whether the most recent query result consists of
	// exactly one row with one column holding exp.
	isSingleValue := func(exp int) bool {
		if len(res) != 1 || len(res[0]) != 1 {
			return false
		}
		return res[0][0] == strconv.Itoa(exp)
	}

	// Sum the numbers in the NumToStr table: 1 + 2 + ... + n^2.
	const expSum = n * n * (n*n + 1) / 2
	res = r.QueryStr("SELECT SUM(y) FROM NumToStr")
	if !isSingleValue(expSum) {
		t.Errorf("expected [[%d]], got %s", expSum, res)
	}

	// Count the rows in the NumToStr table.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr")
	if !isSingleValue(n * n) {
		t.Errorf("expected [[%d]], got %s", n*n, res)
	}

	// Count how many numbers contain the digit 5.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr WHERE str LIKE '%five%'")
	hasDigitFive := func(v int) bool {
		for ; v > 0; v /= 10 {
			if v%10 == 5 {
				return true
			}
		}
		return false
	}
	expFives := 0
	for i := 1; i <= n*n; i++ {
		if hasDigitFive(i) {
			expFives++
		}
	}
	if !isSingleValue(expFives) {
		t.Errorf("expected [[%d]], got %s", expFives, res)
	}
}
예제 #4
0
// checkDistAggregationInfo tests that a flow with multiple local stages and a
// final stage (as described by the given DistAggregationInfo) produces the same
// result as a naive aggregation flow that has a single non-distributed stage.
//
// Both types of flows are set up and run against the first numRows rows of the
// given table. We assume the table's first column is the primary key, with
// values from 1 to numRows. colIdx is the index of the column that is read and
// aggregated; it should be a non-PK column that works with the function.
//
// fn is the aggregation used by the non-distributed flow; info.LocalStage and
// info.FinalStage are the aggregations used by the local and final stages of
// the distributed flow. Any mismatch between the two results (row length,
// column types or string representation) is reported through t.
func checkDistAggregationInfo(
	t *testing.T,
	srv serverutils.TestServerInterface,
	tableDesc *sqlbase.TableDescriptor,
	colIdx int,
	numRows int,
	fn distsqlrun.AggregatorSpec_Func,
	info DistAggregationInfo,
) {
	colType := tableDesc.Columns[colIdx].Type

	// makeTableReader returns a table reader that outputs column colIdx for the
	// span between the primary index keys of startPK and endPK, and sends its
	// output to the local stream with the given ID.
	makeTableReader := func(startPK, endPK int, streamID int) distsqlrun.ProcessorSpec {
		tr := distsqlrun.TableReaderSpec{
			Table:         *tableDesc,
			OutputColumns: []uint32{uint32(colIdx)},
			Spans:         make([]distsqlrun.TableReaderSpan, 1),
		}

		var err error
		tr.Spans[0].Span.Key, err = sqlbase.MakePrimaryIndexKey(tableDesc, startPK)
		if err != nil {
			t.Fatal(err)
		}
		tr.Spans[0].Span.EndKey, err = sqlbase.MakePrimaryIndexKey(tableDesc, endPK)
		if err != nil {
			t.Fatal(err)
		}

		return distsqlrun.ProcessorSpec{
			Core: distsqlrun.ProcessorCoreUnion{TableReader: &tr},
			Output: []distsqlrun.OutputRouterSpec{{
				Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(streamID)},
				},
			}},
		}
	}

	// First run a flow that aggregates all the rows without any local stages.

	rowsNonDist := runTestFlow(
		t, srv,
		makeTableReader(1, numRows+1, 0),
		distsqlrun.ProcessorSpec{
			Input: []distsqlrun.InputSyncSpec{{
				Type:        distsqlrun.InputSyncSpec_UNORDERED,
				ColumnTypes: []sqlbase.ColumnType{colType},
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: 0},
				},
			}},
			Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{
				Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: fn, ColIdx: 0}},
			}},
			Output: []distsqlrun.OutputRouterSpec{{
				Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_SYNC_RESPONSE},
				},
			}},
		},
	)

	// Now run a flow with (up to) 4 separate table readers, each with its own
	// local stage, all feeding into a single final stage.

	numParallel := 4
	// The type outputted by the local stage can be different than the input type
	// (e.g. DECIMAL instead of INT).
	_, intermediaryType, err := distsqlrun.GetAggregateInfo(fn, colType)
	if err != nil {
		t.Fatal(err)
	}

	// Never use more readers than there are rows: the rows are divided evenly
	// between the readers below, so with more readers than rows some of them
	// would get identical start and end primary keys (an empty span).
	if numParallel > numRows {
		numParallel = numRows
	}
	// Keep at least one reader so the final stage always has an input stream.
	if numParallel < 1 {
		numParallel = 1
	}
	finalProc := distsqlrun.ProcessorSpec{
		Input: []distsqlrun.InputSyncSpec{{
			Type:        distsqlrun.InputSyncSpec_UNORDERED,
			ColumnTypes: []sqlbase.ColumnType{intermediaryType},
		}},
		Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{
			Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: info.FinalStage, ColIdx: 0}},
		}},
		Output: []distsqlrun.OutputRouterSpec{{
			Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
			Streams: []distsqlrun.StreamEndpointSpec{
				{Type: distsqlrun.StreamEndpointSpec_SYNC_RESPONSE},
			},
		}},
	}
	var procs []distsqlrun.ProcessorSpec
	for i := 0; i < numParallel; i++ {
		// Reader i covers primary keys [1+i*numRows/numParallel,
		// 1+(i+1)*numRows/numParallel) and writes to stream 2*i; its local
		// aggregator reads stream 2*i and writes to stream 2*i+1, which is one
		// of the inputs of the final stage.
		tr := makeTableReader(1+i*numRows/numParallel, 1+(i+1)*numRows/numParallel, 2*i)
		agg := distsqlrun.ProcessorSpec{
			Input: []distsqlrun.InputSyncSpec{{
				Type:        distsqlrun.InputSyncSpec_UNORDERED,
				ColumnTypes: []sqlbase.ColumnType{colType},
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2 * i)},
				},
			}},
			Core: distsqlrun.ProcessorCoreUnion{Aggregator: &distsqlrun.AggregatorSpec{
				Aggregations: []distsqlrun.AggregatorSpec_Aggregation{{Func: info.LocalStage, ColIdx: 0}},
			}},
			Output: []distsqlrun.OutputRouterSpec{{
				Type: distsqlrun.OutputRouterSpec_PASS_THROUGH,
				Streams: []distsqlrun.StreamEndpointSpec{
					{Type: distsqlrun.StreamEndpointSpec_LOCAL, StreamID: distsqlrun.StreamID(2*i + 1)},
				},
			}},
		}
		procs = append(procs, tr, agg)
		finalProc.Input[0].Streams = append(finalProc.Input[0].Streams, distsqlrun.StreamEndpointSpec{
			Type:     distsqlrun.StreamEndpointSpec_LOCAL,
			StreamID: distsqlrun.StreamID(2*i + 1),
		})
	}
	procs = append(procs, finalProc)
	rowsDist := runTestFlow(t, srv, procs...)

	// Both flows are inspected through their first row below; fail cleanly
	// instead of panicking on an out-of-range index if either returned nothing.
	if len(rowsDist) == 0 || len(rowsNonDist) == 0 {
		t.Fatalf("expected at least one row from each flow (dist: %d non-dist: %d)",
			len(rowsDist), len(rowsNonDist))
	}

	if len(rowsDist[0]) != len(rowsNonDist[0]) {
		t.Errorf("different row lengths (dist: %d non-dist: %d)", len(rowsDist[0]), len(rowsNonDist[0]))
	} else {
		for i := range rowsDist[0] {
			tDist := rowsDist[0][i].Type.String()
			tNonDist := rowsNonDist[0][i].Type.String()
			if tDist != tNonDist {
				t.Errorf("different type for column %d (dist: %s non-dist: %s)", i, tDist, tNonDist)
			}
		}
	}
	if rowsDist.String() != rowsNonDist.String() {
		t.Errorf("different results\nw/o local stage:   %s\nwith local stage:  %s", rowsNonDist, rowsDist)
	}
}