예제 #1
0
파일: runner.go 프로젝트: sywxf/hood
// main is the migration runner. It applies or rolls back the migrations that
// are declared as methods on M, records progress in the Migrations table, and
// finally regenerates db/schema.go from every migration that is still applied.
//
// Command-line arguments (any order; unknown arguments are ignored):
//
//	db:migrate  apply every pending Up migration; without it, the single
//	            most recently applied migration is rolled back
//	-v          enable logging on each migration transaction
//
// Migration methods are discovered via reflection and must be named
// <Name>_<timestamp>_Up and <Name>_<timestamp>_Down (suffix matched
// case-insensitively). The database configuration is read from
// db/config.json under the working directory, using the environment named
// by the HOOD_ENV variable.
func main() {
	wd, err := os.Getwd()
	if err != nil {
		panic(err)
	}

	// Determine direction.
	up := false
	verbose := false
	for i := 1; i < len(os.Args); i++ {
		switch os.Args[i] {
		case "db:migrate":
			up = true
		case "-v":
			verbose = true
		}
	}
	if up {
		fmt.Println("applying migrations...")
	} else {
		fmt.Println("rolling back...")
	}

	// Collect up/down migration methods keyed by their timestamp.
	v := reflect.ValueOf(&M{})
	vt := v.Type()
	numMethods := v.NumMethod()
	stamps := make([]int, 0, numMethods)
	ups := make(map[int]reflect.Method)
	downs := make(map[int]reflect.Method)
	for i := 0; i < numMethods; i++ {
		method := vt.Method(i)
		chunks := strings.Split(method.Name, "_")
		l := len(chunks)
		if l < 3 {
			continue
		}
		ts, err := strconv.Atoi(chunks[l-2])
		if err != nil {
			// Not a migration method: a non-numeric timestamp must not be
			// registered under timestamp 0.
			continue
		}
		switch strings.ToLower(chunks[l-1]) {
		case "up":
			ups[ts] = method
			stamps = append(stamps, ts)
		case "down":
			downs[ts] = method
		}
	}
	sort.Ints(stamps)

	// Open hood.
	hd, err := hood.Load(
		path.Join(wd, "db", "config.json"),
		os.Getenv("HOOD_ENV"),
	)
	if err != nil {
		panic(err)
	}

	// Check migration table.
	err = hd.CreateTableIfNotExists(&Migrations{})
	if err != nil {
		panic(err)
	}
	var rows []Migrations
	err = hd.Find(&rows)
	if err != nil {
		panic(err)
	}
	// info.Current holds the timestamp of the latest applied migration; it
	// stays 0 when the table has no row yet.
	info := Migrations{}
	if len(rows) > 0 {
		info = rows[0]
	}

	runCount := 0
	// run executes one migration method in its own transaction and, in that
	// same transaction, saves current as the latest applied timestamp. Save's
	// error is checked so the runner never commits a migration while silently
	// failing to record it.
	run := func(method reflect.Method, current int) {
		tx := hd.Begin()
		tx.Log = verbose
		method.Func.Call([]reflect.Value{v, reflect.ValueOf(tx)})
		info.Current = current
		if _, err := tx.Save(&info); err != nil {
			panic(err)
		}
		if err := tx.Commit(); err != nil {
			panic(err)
		}
		runCount++
	}

	for i, ts := range stamps {
		if up {
			// Apply every migration newer than the recorded one, oldest first.
			if ts > info.Current {
				method := ups[ts]
				run(method, ts)
				fmt.Printf("applied %s\n", method.Name)
			}
			continue
		}
		// Rolling back: only the currently applied migration is undone, and
		// the previous timestamp (or 0) becomes the new current one.
		if ts != info.Current {
			continue
		}
		method, ok := downs[ts]
		if !ok {
			panic(fmt.Sprintf("no down migration for timestamp %d", ts))
		}
		prev := 0
		if i > 0 {
			prev = stamps[i-1]
		}
		run(method, prev)
		fmt.Printf("rolled back %s\n", method.Name)
		break
	}
	if up {
		fmt.Printf("applied %d migrations\n", runCount)
	} else {
		fmt.Printf("rolled back %d migrations\n", runCount)
	}

	// Replay every applied Up migration against a dry-run hood to obtain the
	// resulting schema definition, then write and gofmt it.
	fmt.Println("generating new schema...")
	dry := hood.Dry()
	for _, ts := range stamps {
		if ts <= info.Current {
			method := ups[ts]
			method.Func.Call([]reflect.Value{v, reflect.ValueOf(dry)})
		}
	}
	schema := fmt.Sprintf(
		"package db\n\nimport (\n\t\"github.com/eaigner/hood\"\n)\n\n%s",
		dry.SchemaDefinition(),
	)
	schemaPath := path.Join(wd, "db", "schema.go")
	err = ioutil.WriteFile(schemaPath, []byte(schema), 0666)
	if err != nil {
		panic(err)
	}
	err = exec.Command("go", "fmt", schemaPath).Run()
	if err != nil {
		panic(err)
	}
	fmt.Printf("wrote schema %s\n", schemaPath)
	fmt.Println("done.")
}
예제 #2
0
// main applies or rolls back migrations according to the package-level
// steps value, then regenerates the Go schema file at schemaPath from every
// migration that is currently applied.
//
// steps > 0 applies up to steps pending migrations, oldest first; steps < 0
// rolls back up to |steps| applied migrations, newest first (-1 undoes only
// the latest one); steps == 0 only regenerates the schema.
//
// steps, driver, source, schemaPath and apply are declared elsewhere in the
// package (presumably command-line flags and a helper that runs one
// migration method and records the new current stamp — TODO confirm).
// Migration methods on M are discovered via reflection and must be named
// <Name>_<timestamp>_Up / <Name>_<timestamp>_Down (suffix matched
// case-insensitively).
func main() {
	// Print action.
	switch {
	case steps > 0:
		log.Printf("applying migrations...")
	case steps == -1:
		log.Printf("rolling back by 1...")
	case steps < 0:
		log.Printf("reset. rolling back all migrations...")
	}

	// Parse migrations, keyed by timestamp.
	stamps := []int{}
	ups := map[int]reflect.Method{}
	downs := map[int]reflect.Method{}

	structVal := reflect.ValueOf(&M{})
	structType := structVal.Type()
	for i := 0; i < structVal.NumMethod(); i++ {
		method := structType.Method(i)
		c := strings.Split(method.Name, "_")
		if len(c) < 3 {
			continue
		}
		stamp, err := strconv.Atoi(c[len(c)-2])
		if err != nil {
			// Not a migration method: a non-numeric timestamp must not be
			// registered under timestamp 0.
			continue
		}
		// Case-insensitive so that e.g. "_up" is never mistaken for a down
		// migration.
		switch strings.ToLower(c[len(c)-1]) {
		case "up":
			ups[stamp] = method
			stamps = append(stamps, stamp)
		case "down":
			downs[stamp] = method
		}
	}

	sort.Ints(stamps)

	// Open hood.
	hd, err := hood.Open(driver, source)
	if err != nil {
		panic(err)
	}
	hd.Log = true

	// Create migration table if necessary. The creation error is checked
	// explicitly so a failed CREATE TABLE is reported here instead of being
	// dropped.
	tx := hd.Begin()
	err = tx.CreateTableIfNotExists(&Migrations{})
	if err != nil {
		panic(err)
	}
	err = tx.Commit()
	if err != nil {
		panic(err)
	}

	// Check if any previous migrations have been run; the table is expected
	// to hold at most one row.
	var rows []Migrations
	err = hd.Find(&rows)
	if err != nil {
		panic(err)
	}
	if len(rows) > 1 {
		panic("invalid migrations table")
	}
	info := Migrations{}
	if len(rows) > 0 {
		info = rows[0]
	}

	// Apply. cur counts how many migrations have been considered in the
	// current direction so that at most |steps| of them run.
	cur := 0
	count := 0
	if steps > 0 {
		for _, stamp := range stamps {
			if stamp > info.Current {
				if cur++; cur <= steps {
					apply(stamp, stamp, &count, hd, &info, structVal, ups[stamp])
				}
			}
		}
	} else if steps < 0 {
		for i := len(stamps) - 1; i >= 0; i-- {
			stamp := stamps[i]
			// After rolling back stamp, the previous stamp (or 0) is current.
			next := 0
			if i > 0 {
				next = stamps[i-1]
			}
			if stamp <= info.Current {
				if cur--; cur >= steps {
					apply(stamp, next, &count, hd, &info, structVal, downs[stamp])
				}
			}
		}
	}

	if steps > 0 {
		log.Printf("applied %d migrations", count)
	} else if steps < 0 {
		log.Printf("rolled back %d migrations", count)
	}

	log.Printf("generating new schema... %s", schemaPath)

	// Replay every applied Up migration against a dry-run hood to obtain the
	// resulting schema, then write and gofmt it.
	dry := hood.Dry()
	for _, ts := range stamps {
		if ts <= info.Current {
			method := ups[ts]
			method.Func.Call([]reflect.Value{structVal, reflect.ValueOf(dry)})
		}
	}
	err = ioutil.WriteFile(schemaPath, []byte(dry.GoSchema()), 0666)
	if err != nil {
		panic(err)
	}
	err = exec.Command("go", "fmt", schemaPath).Run()
	if err != nil {
		panic(err)
	}
	log.Printf("wrote schema %s", schemaPath)
	log.Printf("done.")
}