func TestGetFlagInfo(t *testing.T) { db, err := openDB() defer db.Close() flg := steward.Flag{ID: 1, Flag: "asdfasdf", Round: 5345, TeamID: 433, ServiceID: 353, Cred: "1:2"} err = steward.AddFlag(db.db, flg) new_flg, err := steward.GetFlagInfo(db.db, flg.Flag) if err != nil { log.Fatalln("Cannot get flag info:", err) } if new_flg != flg { log.Fatalln("Readed flag is not equal to writed before") } }
func TestCountRound(*testing.T) { db, err := openDB() if err != nil { log.Fatalln("Open database failed:", err) } defer db.Close() fillTestTeams(db.db) fillTestServices(db.db) priv, err := vexillary.GenerateKey() if err != nil { log.Fatalln("Generate key failed:", err) } round, err := steward.NewRound(db.db, time.Minute) if err != nil { log.Fatalln("Create new round failed:", err) } teams, err := steward.GetTeams(db.db) if err != nil { log.Fatalln("Get teams failed:", err) } services, err := steward.GetServices(db.db) if err != nil { log.Fatalln("Get services failed:", err) } flags := make([]string, 0) for _, team := range teams { for _, svc := range services { flag, err := vexillary.GenerateFlag(priv) if err != nil { log.Fatalln("Generate flag failed:", err) } flags = append(flags, flag) flg := steward.Flag{ID: -1, Flag: flag, Round: round, TeamID: team.ID, ServiceID: svc.ID, Cred: ""} err = steward.AddFlag(db.db, flg) if err != nil { log.Fatalln("Add flag to database failed:", err) } err = steward.PutStatus(db.db, steward.Status{ Round: round, TeamID: team.ID, ServiceID: svc.ID, State: steward.StatusUP}) if err != nil { log.Fatalln("Put status to database failed:", err) } } } flag1, err := steward.GetFlagInfo(db.db, flags[2]) if err != nil { log.Fatalln("Get flag info failed:", err) } err = steward.CaptureFlag(db.db, flag1.ID, teams[2].ID) if err != nil { log.Fatalln("Capture flag failed:", err) } flag2, err := steward.GetFlagInfo(db.db, flags[7]) if err != nil { log.Fatalln("Get flag info failed:", err) } err = steward.CaptureFlag(db.db, flag2.ID, teams[3].ID) if err != nil { log.Fatalln("Capture flag failed:", err) } err = counter.CountRound(db.db, round, teams, services) if err != nil { log.Fatalln("Count round failed:", err) } res, err := steward.GetRoundResult(db.db, teams[0].ID, round) if err != nil || res.AttackScore != 0.0 || res.DefenceScore != 1.75 { log.Fatalln("Invalid result:", res) } res, err = steward.GetRoundResult(db.db, teams[1].ID, round) if err != nil || res.AttackScore != 0.0 || res.DefenceScore != 1.75 { log.Fatalln("Invalid result:", res) } res, err = steward.GetRoundResult(db.db, teams[2].ID, round) if err != nil || res.AttackScore != 0.25 || res.DefenceScore != 2.0 { log.Fatalln("Invalid result:", res) } res, err = steward.GetRoundResult(db.db, teams[3].ID, round) if err != nil || res.AttackScore != 0.25 || res.DefenceScore != 2.0 { log.Fatalln("Invalid result:", res) } }
func handler(conn net.Conn, db *sql.DB, priv *rsa.PrivateKey, attackFlow chan scoreboard.Attack) { addr := conn.RemoteAddr().String() defer conn.Close() fmt.Fprint(conn, greetingMsg) flag, err := bufio.NewReader(conn).ReadString('\n') if err != nil { log.Println("Read error:", err) } flag = strings.Trim(flag, "\n") log.Printf("\tGet flag %s from %s", flag, addr) valid, err := vexillary.ValidFlag(flag, priv.PublicKey) if err != nil { log.Println("\tValidate flag failed:", err) } if !valid { fmt.Fprint(conn, invalidFlagMsg) return } exist, err := steward.FlagExist(db, flag) if err != nil { log.Println("\tExist flag check failed:", err) fmt.Fprint(conn, internalErrorMsg) return } if !exist { fmt.Fprint(conn, flagDoesNotExistMsg) return } flg, err := steward.GetFlagInfo(db, flag) if err != nil { log.Println("\tGet flag info failed:", err) fmt.Fprint(conn, internalErrorMsg) return } captured, err := steward.AlreadyCaptured(db, flg.ID) if err != nil { log.Println("\tAlready captured check failed:", err) fmt.Fprint(conn, internalErrorMsg) return } if captured { fmt.Fprint(conn, alreadyCapturedMsg) return } team, err := teamByAddr(db, addr) if err != nil { log.Println("\tGet team by ip failed:", err) fmt.Fprint(conn, invalidTeamMsg) return } if flg.TeamID == team.ID { log.Printf("\tTeam %s try to send their flag", team.Name) fmt.Fprint(conn, flagYoursMsg) return } round, err := steward.CurrentRound(db) if round.ID != flg.Round { log.Printf("\t%s try to send flag from past round", team.Name) fmt.Fprint(conn, flagExpiredMsg) return } roundEndTime := round.StartTime.Add(round.Len) if time.Now().After(roundEndTime) { log.Printf("\t%s try to send flag from finished round", team.Name) fmt.Fprint(conn, flagExpiredMsg) return } halfStatus := steward.Status{flg.Round, team.ID, flg.ServiceID, steward.StatusUnknown} state, err := steward.GetState(db, halfStatus) if state != steward.StatusUP { log.Printf("\t%s service not ok, cannot capture", team.Name) fmt.Fprint(conn, serviceNotUpMsg) return } err = steward.CaptureFlag(db, flg.ID, team.ID) if err != nil { log.Println("\tCapture flag failed:", err) fmt.Fprint(conn, internalErrorMsg) return } go func() { attack := scoreboard.Attack{ Attacker: team.ID, Victim: flg.TeamID, Service: flg.ServiceID, Timestamp: time.Now().Unix(), } select { case attackFlow <- attack: default: _ = <-attackFlow attackFlow <- attack } }() fmt.Fprint(conn, capturedMsg) }