Example No. 1
func TestClusterFlow(t *testing.T) {
	defer leaktest.AfterTest(t)()
	const numRows = 100

	args := base.TestClusterArgs{ReplicationMode: base.ReplicationManual}
	tc := serverutils.StartTestCluster(t, 3, args)
	defer tc.Stopper().Stop()

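	// sumDigitsFn maps a row number to the sum of its decimal digits (e.g. 17 -> 8).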
	sumDigitsFn := func(row int) parser.Datum {
		sum := 0
		for row > 0 {
			sum += row % 10
			row /= 10
		}
		return parser.NewDInt(parser.DInt(sum))
	}

	sqlutils.CreateTable(t, tc.ServerConn(0), "t",
		"num INT PRIMARY KEY, digitsum INT, numstr STRING, INDEX s (digitsum)",
		numRows,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, sumDigitsFn, sqlutils.RowEnglishFn))

	kvDB := tc.Server(0).KVClient().(*client.DB)
	desc := sqlbase.GetTableDescriptor(kvDB, "test", "t")
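	// makeIndexSpan builds a span over the digitsum index covering values in [start, end).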
	makeIndexSpan := func(start, end int) TableReaderSpan {
		var span roachpb.Span
		prefix := roachpb.Key(sqlbase.MakeIndexKeyPrefix(desc, desc.Indexes[0].ID))
		span.Key = append(prefix, encoding.EncodeVarintAscending(nil, int64(start))...)
		span.EndKey = append(span.EndKey, prefix...)
		span.EndKey = append(span.EndKey, encoding.EncodeVarintAscending(nil, int64(end))...)
		return TableReaderSpan{Span: span}
	}

	// Set up table readers on three hosts feeding data into a join reader on
	// the third host. This is a basic test for the distributed flow
	// infrastructure, including local and remote streams.
	//
	// Note that the ranges won't necessarily be local to the table readers, but
	// that doesn't matter for the purposes of this test.

	// Start a tracing span (useful for inspecting the spans in Lightstep).
	sp, err := tracing.JoinOrNew(tracing.NewTracer(), nil, "cluster test")
	if err != nil {
		t.Fatal(err)
	}
	ctx := opentracing.ContextWithSpan(context.Background(), sp)
	defer sp.Finish()

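	// Each table reader scans the secondary index s (IndexIdx 1) over a disjoint
	// digitsum range: [0, 8), [8, 12) and [12, 100).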
	tr1 := TableReaderSpec{
		Table:         *desc,
		IndexIdx:      1,
		OutputColumns: []uint32{0, 1},
		Spans:         []TableReaderSpan{makeIndexSpan(0, 8)},
	}

	tr2 := TableReaderSpec{
		Table:         *desc,
		IndexIdx:      1,
		OutputColumns: []uint32{0, 1},
		Spans:         []TableReaderSpan{makeIndexSpan(8, 12)},
	}

	tr3 := TableReaderSpec{
		Table:         *desc,
		IndexIdx:      1,
		OutputColumns: []uint32{0, 1},
		Spans:         []TableReaderSpan{makeIndexSpan(12, 100)},
	}

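	// The join reader looks up each input row in the table's primary index and
	// emits only column 2 (numstr).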
	jr := JoinReaderSpec{
		Table:         *desc,
		OutputColumns: []uint32{2},
	}

	txn := client.NewTxn(ctx, *kvDB)
	fid := FlowID{uuid.MakeV4()}

	req1 := &SetupFlowRequest{Txn: txn.Proto}
	req1.Flow = FlowSpec{
		FlowID: fid,
		Processors: []ProcessorSpec{{
			Core: ProcessorCoreUnion{TableReader: &tr1},
			Output: []OutputRouterSpec{{
				Type: OutputRouterSpec_MIRROR,
				Streams: []StreamEndpointSpec{
					{StreamID: 0, Mailbox: &MailboxSpec{TargetAddr: tc.Server(2).ServingAddr()}},
				},
			}},
		}},
	}

	req2 := &SetupFlowRequest{Txn: txn.Proto}
	req2.Flow = FlowSpec{
		FlowID: fid,
		Processors: []ProcessorSpec{{
			Core: ProcessorCoreUnion{TableReader: &tr2},
			Output: []OutputRouterSpec{{
				Type: OutputRouterSpec_MIRROR,
				Streams: []StreamEndpointSpec{
					{StreamID: 1, Mailbox: &MailboxSpec{TargetAddr: tc.Server(2).ServingAddr()}},
				},
			}},
		}},
	}

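	// The flow on the third node hosts the last table reader plus the join reader;
	// its ordered input sync merges the two remote streams and the local stream
	// (StreamID 2) by digitsum.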
	req3 := &SetupFlowRequest{Txn: txn.Proto}
	req3.Flow = FlowSpec{
		FlowID: fid,
		Processors: []ProcessorSpec{
			{
				Core: ProcessorCoreUnion{TableReader: &tr3},
				Output: []OutputRouterSpec{{
					Type: OutputRouterSpec_MIRROR,
					Streams: []StreamEndpointSpec{
						{StreamID: StreamID(2)},
					},
				}},
			},
			{
				Input: []InputSyncSpec{{
					Type:     InputSyncSpec_ORDERED,
					Ordering: Ordering{Columns: []Ordering_Column{{1, Ordering_Column_ASC}}},
					Streams: []StreamEndpointSpec{
						{StreamID: 0, Mailbox: &MailboxSpec{}},
						{StreamID: 1, Mailbox: &MailboxSpec{}},
						{StreamID: StreamID(2)},
					},
				}},
				Core: ProcessorCoreUnion{JoinReader: &jr},
				Output: []OutputRouterSpec{{
					Type:    OutputRouterSpec_MIRROR,
					Streams: []StreamEndpointSpec{{Mailbox: &MailboxSpec{SimpleResponse: true}}},
				}}},
		},
	}

	if err := SetFlowRequestTrace(ctx, req1); err != nil {
		t.Fatal(err)
	}
	if err := SetFlowRequestTrace(ctx, req2); err != nil {
		t.Fatal(err)
	}
	if err := SetFlowRequestTrace(ctx, req3); err != nil {
		t.Fatal(err)
	}

	var clients []DistSQLClient
	for i := 0; i < 3; i++ {
		s := tc.Server(i)
		conn, err := s.RPCContext().GRPCDial(s.ServingAddr())
		if err != nil {
			t.Fatal(err)
		}
		clients = append(clients, NewDistSQLClient(conn))
	}

	if log.V(1) {
		log.Infof(ctx, "Setting up flow on 0")
	}
	if resp, err := clients[0].SetupFlow(ctx, req1); err != nil {
		t.Fatal(err)
	} else if resp.Error != nil {
		t.Fatal(resp.Error)
	}

	if log.V(1) {
		log.Infof(ctx, "Setting up flow on 1")
	}
	if resp, err := clients[1].SetupFlow(ctx, req2); err != nil {
		t.Fatal(err)
	} else if resp.Error != nil {
		t.Fatal(resp.Error)
	}

	if log.V(1) {
		log.Infof(ctx, "Running flow on 2")
	}
	stream, err := clients[2].RunSimpleFlow(ctx, req3)
	if err != nil {
		t.Fatal(err)
	}

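	// Read the response stream, decoding rows from each message until EOF.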
	var decoder StreamDecoder
	var rows sqlbase.EncDatumRows
	for {
		msg, err := stream.Recv()
		if err != nil {
			if err == io.EOF {
				break
			}
			t.Fatal(err)
		}
		err = decoder.AddMessage(msg)
		if err != nil {
			t.Fatal(err)
		}
		rows = testGetDecodedRows(t, &decoder, rows)
	}
	if done, trailerErr := decoder.IsDone(); !done {
		t.Fatal("stream not done")
	} else if trailerErr != nil {
		t.Fatal("error in the stream trailer:", trailerErr)
	}
	// The result should be all the numbers in string form, ordered by the
	// digit sum (and then by number).
	var results []string
	for sum := 1; sum <= 50; sum++ {
		for i := 1; i <= numRows; i++ {
			if int(*sumDigitsFn(i).(*parser.DInt)) == sum {
				results = append(results, fmt.Sprintf("['%s']", sqlutils.IntToEnglish(i)))
			}
		}
	}
	expected := strings.Join(results, " ")
	expected = "[" + expected + "]"
	if rowStr := rows.String(); rowStr != expected {
		t.Errorf("Result: %s\n Expected: %s\n", rowStr, expected)
	}
}
Example No. 2
func TestDistSQLJoinAndAgg(t *testing.T) {
	defer leaktest.AfterTest(t)()

	// This test sets up a distributed join between two tables:
	//  - a NumToSquare table of size n that maps integers from 1 to n to their
	//    squares
	//  - a NumToStr table of size n^2 that maps integers to their string
	//    representations. This table is split and distributed to all the nodes.
	const n = 100
	const numNodes = 5

	tc := serverutils.StartTestCluster(t, numNodes,
		base.TestClusterArgs{
			ReplicationMode: base.ReplicationManual,
			ServerArgs: base.TestServerArgs{
				UseDatabase: "test",
			},
		})
	defer tc.Stopper().Stop()
	cdb := tc.Server(0).KVClient().(*client.DB)

	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToSquare", "x INT PRIMARY KEY, xsquared INT",
		n,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, func(row int) parser.Datum {
			return parser.NewDInt(parser.DInt(row * row))
		}),
	)

	sqlutils.CreateTable(
		t, tc.ServerConn(0), "NumToStr", "y INT PRIMARY KEY, str STRING",
		n*n,
		sqlutils.ToRowFn(sqlutils.RowIdxFn, sqlutils.RowEnglishFn),
	)
	// Split the table into multiple ranges, with each range having a single
	// replica on a certain node. This forces the query to be distributed.
	//
	// TODO(radu): this approach should be generalized into test infrastructure
	// (perhaps by adding functionality to logic tests).
	// TODO(radu): we should verify that the plan is indeed distributed as
	// intended.
	descNumToStr := sqlbase.GetTableDescriptor(cdb, "test", "NumToStr")

	// split introduces a split and moves the right range to a given node.
	split := func(val int, targetNode int) {
		pik, err := sqlbase.MakePrimaryIndexKey(descNumToStr, val)
		if err != nil {
			t.Fatal(err)
		}

		splitKey := keys.MakeRowSentinelKey(pik)
		_, rightRange, err := tc.Server(0).SplitRange(splitKey)
		if err != nil {
			t.Fatal(err)
		}
		splitKey = rightRange.StartKey.AsRawKey()
		rightRange, err = tc.AddReplicas(splitKey, tc.Target(targetNode))
		if err != nil {
			t.Fatal(err)
		}

		// This transfer is necessary to avoid waiting for the lease to expire when
		// removing the first replica.
		if err := tc.TransferRangeLease(rightRange, tc.Target(targetNode)); err != nil {
			t.Fatal(err)
		}
		if _, err := tc.RemoveReplicas(splitKey, tc.Target(0)); err != nil {
			t.Fatal(err)
		}
	}
	// split moves the right range, so we split things back to front.
	for i := numNodes - 1; i > 0; i-- {
		split(n*n/numNodes*i, i)
	}

	r := sqlutils.MakeSQLRunner(t, tc.ServerConn(0))
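	// Use a single connection so the session setting below applies to every query.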
	r.DB.SetMaxOpenConns(1)
	r.Exec("SET DIST_SQL = ALWAYS")
	res := r.QueryStr("SELECT x, str FROM NumToSquare JOIN NumToStr ON y = xsquared")
	// Verify that res contains one entry for each integer, with the string
	// representation of its square, e.g.:
	//  [1, one]
	//  [2, two]
	//  [3, nine]
	//  [4, one-six]
	// (but not necessarily in order).
	if len(res) != n {
		t.Fatalf("expected %d rows, got %d", n, len(res))
	}
	resMap := make(map[int]string)
	for _, row := range res {
		if len(row) != 2 {
			t.Fatalf("invalid row %v", row)
		}
		n, err := strconv.Atoi(row[0])
		if err != nil {
			t.Fatalf("error parsing row %v: %s", row, err)
		}
		resMap[n] = row[1]
	}
	for i := 1; i <= n; i++ {
		if resMap[i] != sqlutils.IntToEnglish(i*i) {
			t.Errorf("invalid string for %d: %s", i, resMap[i])
		}
	}

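	// checkRes reports whether res is a single row with a single column equal to exp.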
	checkRes := func(exp int) bool {
		return len(res) == 1 && len(res[0]) == 1 && res[0][0] == strconv.Itoa(exp)
	}

	// Sum the numbers in the NumToStr table.
	res = r.QueryStr("SELECT SUM(y) FROM NumToStr")
	if exp := n * n * (n*n + 1) / 2; !checkRes(exp) {
		t.Errorf("expected [[%d]], got %s", exp, res)
	}

	// Count the rows in the NumToStr table.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr")
	if !checkRes(n * n) {
		t.Errorf("expected [[%d]], got %s", n*n, res)
	}

	// Count how many numbers contain the digit 5.
	res = r.QueryStr("SELECT COUNT(*) FROM NumToStr WHERE str LIKE '%five%'")
	exp := 0
	for i := 1; i <= n*n; i++ {
		for x := i; x > 0; x /= 10 {
			if x%10 == 5 {
				exp++
				break
			}
		}
	}
	if !checkRes(exp) {
		t.Errorf("expected [[%d]], got %s", exp, res)
	}
}