func checkOneOp(in *proto.PkgOneOp, req *PkgArgs, au Authorize) bool { n, err := in.Decode(req.Pkg) if err != nil || n != len(req.Pkg) { in.ErrCode = table.EcDecodeFail } if in.ErrCode == 0 && in.DbId == proto.AdminDbId { in.ErrCode = table.EcInvDbId } if in.ErrCode == 0 && !au.IsAuth(in.DbId) { in.ErrCode = table.EcNoPrivilege } if in.ErrCode != 0 { in.CtrlFlag &^= 0xFF // Clear all ctrl flags in.CtrlFlag |= proto.CtrlErrCode } return in.ErrCode == 0 }
func (tbl *Table) Auth(req *PkgArgs, au Authorize) []byte { var in proto.PkgOneOp _, err := in.Decode(req.Pkg) if err != nil { in.CtrlFlag &^= 0xFF // Clear all ctrl flags in.SetErrCode(table.EcDecodeFail) return replyHandle(&in) } in.ErrCode = 0 var authDB uint8 var already bool if au.IsAuth(proto.AdminDbId) { authDB = proto.AdminDbId already = true } else if in.DbId != proto.AdminDbId && au.IsAuth(in.DbId) { authDB = in.DbId already = true } if already { return replyHandle(&in) } password := string(in.RowKey) tbl.mtx.Lock() // Admin password if tbl.authPwd == nil || tbl.authPwd[proto.AdminDbId] == password { authDB = proto.AdminDbId } else { // Selected DB password if len(password) > 0 && tbl.authPwd[in.DbId] == password { authDB = in.DbId } else { in.SetErrCode(table.EcAuthFailed) } } tbl.mtx.Unlock() // Success if in.ErrCode == 0 { in.DbId = authDB au.SetAuth(authDB) } return replyHandle(&in) }
func (srv *Server) replyOneOp(req *Request, errCode int8) { var out proto.PkgOneOp out.Cmd = req.Cmd out.DbId = req.DbId out.Seq = req.Seq out.ErrCode = errCode if out.ErrCode != 0 { out.CtrlFlag |= proto.CtrlErrCode } var pkg = make([]byte, out.Length()) _, err := out.Encode(pkg) if err != nil { log.Fatalf("Encode failed: %s\n", err) } srv.sendResp(false, req, pkg) }
func TestTableGetCas(t *testing.T) { var in proto.PkgOneOp in.Cmd = proto.CmdGet in.DbId = 1 in.Seq = 10 in.KeyValue = getTestKV(2, []byte("row1"), []byte("col1"), []byte("v1"), 0, 2) out := myGet(in, testAuth, getTestWA(), t) if bytes.Compare(out.Value, []byte("v1")) != 0 { t.Fatalf("Value mismatch: %q", out.Value) } if out.Score != 30 { t.Fatalf("Score mismatch") } if out.Cas == 0 { t.Fatalf("Should return new cas") } // Set in.Cmd = proto.CmdSet in.SetValue(append(out.Value, []byte("-cas")...)) in.SetScore(32) in.SetCas(out.Cas) mySet(in, testAuth, getTestWA(), true, t) // Set again should fail mySet(in, testAuth, getTestWA(), false, t) // Get in.Cmd = proto.CmdGet in.Cas = 0 in.CtrlFlag &^= 0xFF out = myGet(in, testAuth, getTestWA(), t) if bytes.Compare(out.Value, []byte("v1-cas")) != 0 { t.Fatalf("Value mismatch: %q", out.Value) } if out.Score != 32 { t.Fatalf("Score mismatch") } }