/
db.go
84 lines (71 loc) · 1.64 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package main
import (
"fmt"
"log"
"strings"
_ "github.com/go-sql-driver/mysql"
"github.com/guotie/config"
"github.com/guotie/deferinit"
"github.com/jinzhu/gorm"
)
// db is the package-wide gorm handle. It is nil until connectDatabases
// assigns it.
var db *gorm.DB
// init registers connectDatabases with the deferinit package instead of
// connecting at package-load time.
func init() {
	// Third argument handed to deferinit.AddInit.
	// NOTE(review): presumably an ordering/priority value, and the nil second
	// argument presumably means "no teardown function" — confirm against the
	// deferinit package.
	const connectDatabasesOrder = 1000

	deferinit.AddInit(connectDatabases, nil, connectDatabasesOrder)
}
// connectDatabases opens the "justTalk" database, stores the resulting handle
// in the package-level db, and auto-migrates the User, Post, Taxonomy and
// TermRelation models.
//
// It panics if the connection cannot be established. A migration failure is
// logged (not fatal) so that schema problems are no longer silently dropped.
func connectDatabases() {
	var err error

	// Empty user and password make opendb fall back to the configured
	// credentials.
	db, err = opendb("justTalk", "", "")
	if err != nil {
		panic(err)
	}

	// AutoMigrate reports failures through the Error field of the returned
	// handle; surface them instead of discarding the result.
	if err = db.AutoMigrate(&User{}, &Post{}, &Taxonomy{}, &TermRelation{}).Error; err != nil {
		log.Println(err.Error())
	}
	//db.LogMode(true)
}
// opendb opens a gorm connection to the database dbname and verifies it with
// a ping.
//
// Parameters:
//   - dbname: name of the database to connect to.
//   - dbuser: user name; when empty, the "dbuser" config key is used
//     (default "root").
//   - dbpass: password; when empty, the "dbpass" config key is used
//     (default "root").
//
// The backend is selected by the "dbtype" config key (compared
// case-insensitively, default "mysql"). "pg", "postgres" and "postgresql" all
// select the "postgres" dialect. Host and port are read from "dbhost" and
// "dbport"; the mysql DSN additionally reads "dbproto" (default "tcp").
//
// On success it returns the open handle and a nil error. On any failure
// (unsupported dbtype, open error, or ping error) it logs the error and
// returns a nil handle together with that error; a connection that was
// opened but failed the ping is closed before returning.
func opendb(dbname, dbuser, dbpass string) (*gorm.DB, error) {
	var (
		dsn string
		db  gorm.DB
		err error
	)

	if dbuser == "" {
		dbuser = config.GetStringDefault("dbuser", "root")
	}
	if dbpass == "" {
		dbpass = config.GetStringDefault("dbpass", "root")
	}

	dbtype := strings.ToLower(config.GetStringDefault("dbtype", "mysql"))
	switch dbtype {
	case "mysql":
		dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?charset=utf8&parseTime=True&loc=Local",
			dbuser,
			dbpass,
			config.GetStringDefault("dbproto", "tcp"),
			config.GetStringDefault("dbhost", "localhost"),
			config.GetIntDefault("dbport", 3306),
			dbname,
		)
	case "pg", "postgres", "postgresql":
		// Normalize every accepted alias to the dialect name passed to gorm.
		dbtype = "postgres"
		dsn = fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s sslmode=disable",
			dbuser,
			dbpass,
			config.GetStringDefault("dbhost", "127.0.0.1"),
			config.GetIntDefault("dbport", 5432),
			dbname)
	default:
		// Previously this fell through with an empty DSN and an unknown
		// dialect; fail early with a message that names the bad value.
		err = fmt.Errorf("opendb: unsupported dbtype %q", dbtype)
		log.Println(err.Error())
		return nil, err
	}

	db, err = gorm.Open(dbtype, dsn)
	if err != nil {
		log.Println(err.Error())
		return nil, err
	}

	sqlDB := db.DB()
	if err = sqlDB.Ping(); err != nil {
		log.Println(err.Error())
		// Best-effort cleanup so the unusable pool is not leaked; the ping
		// error is the one the caller needs, so a Close error is dropped.
		_ = sqlDB.Close()
		return nil, err
	}

	return &db, nil
}