func TestQueryExecutorTableAcl(t *testing.T) { testUtils := &testUtils{} aclName := fmt.Sprintf("simpleacl-test-%d", rand.Int63()) tableacl.Register(aclName, &simpleacl.Factory{}) tableacl.SetDefaultACL(aclName) db := setUpQueryExecutorTest() query := "select * from test_table limit 1000" expected := &mproto.QueryResult{ Fields: getTestTableFields(), RowsAffected: 0, Rows: [][]sqltypes.Value{}, } db.AddQuery(query, expected) db.AddQuery("select * from test_table where 1 != 1", &mproto.QueryResult{ Fields: getTestTableFields(), }) username := "******" callInfo := &fakeCallInfo{ remoteAddr: "1.2.3.4", username: username, } ctx := callinfo.NewContext(context.Background(), callInfo) if err := tableacl.InitFromBytes( []byte(fmt.Sprintf(`{"test_table":{"READER":"%s"}}`, username))); err != nil { t.Fatalf("unable to load tableacl config, error: %v", err) } qre, sqlQuery := newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict) checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) testUtils.checkEqual(t, expected, qre.Execute()) sqlQuery.disallowQueries() if err := tableacl.InitFromBytes([]byte(`{"test_table":{"READER":"superuser"}}`)); err != nil { t.Fatalf("unable to load tableacl config, error: %v", err) } // without enabling Config.StrictTableAcl qre, sqlQuery = newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict) checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) qre.Execute() sqlQuery.disallowQueries() // enable Config.StrictTableAcl qre, sqlQuery = newTestQueryExecutor( query, ctx, enableRowCache|enableSchemaOverrides|enableStrict|enableStrictTableAcl) defer sqlQuery.disallowQueries() checkPlanID(t, planbuilder.PLAN_PASS_SELECT, qre.plan.PlanId) defer handleAndVerifyTabletError(t, "query should fail because current user do not have read permissions", ErrFail) qre.Execute() }
func checkLoad(configData []byte, valid bool, t *testing.T) { err := tableacl.InitFromBytes(configData) if !valid && err == nil { t.Errorf("expecting parse error none returned") } if valid && err != nil { t.Errorf("unexpected load error: %v", err) } }