/
pgocrypto.go
125 lines (112 loc) · 3.49 KB
/
pgocrypto.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
/*
Package pgocrypto is a simple library for transferring encrypted data between a
Go program and a PostgreSQL database, using only the pgcrypto extension in the
database and Go's standard library in the client.
*/
package pgocrypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"unicode/utf8"
)
// pkcsPad applies very simple PKCS padding, as implemented in pgcrypto: it
// appends padLen bytes, each with the value padLen, where padLen is between 1
// and blockSize, so that the result length is a multiple of blockSize. When
// len(input) is already a multiple of blockSize, a full block of padding is
// appended. input is not modified; a newly allocated slice is returned.
//
// NOTE(review): padLen is stored in a single byte, so blockSize is assumed to
// be in 1..255 (callers here pass the AES block size).
func pkcsPad(input []byte, blockSize int) []byte {
	padLen := blockSize - len(input)%blockSize
	padded := make([]byte, len(input)+padLen)
	copy(padded, input)
	// Fill the tail (everything after the copied input) with the pad length.
	for i := len(input); i < len(padded); i++ {
		padded[i] = byte(padLen)
	}
	return padded
}
// pkcsUnpad is the reverse of pkcsPad: it validates and strips PKCS padding
// from input. An error is returned if input is not a whole, non-zero number of
// blocks, if the final byte is not a valid padding length (1..blockSize), or
// if any of the trailing padding bytes differs from that length.
//
// The returned slice shares input's backing array.
func pkcsUnpad(input []byte, blockSize int) ([]byte, error) {
	if len(input)%blockSize != 0 {
		return nil, fmt.Errorf("input length %d not divisible by block size %d", len(input), blockSize)
	}
	if len(input) < blockSize {
		return nil, fmt.Errorf("input length %d is smaller than block size %d", len(input), blockSize)
	}
	padLen := int(input[len(input)-1])
	// padLen comes from a byte, so it is never negative; zero is invalid.
	if padLen <= 0 || padLen > blockSize {
		return nil, fmt.Errorf("invalid padding length %d", padLen)
	}
	// Every padding byte must repeat the padding length.
	for pos, b := range input[len(input)-padLen:] {
		if int(b) != padLen {
			return nil, fmt.Errorf("padding byte %d at pos %d is not the same as padding length %d", b, pos, padLen)
		}
	}
	return input[:len(input)-padLen], nil
}
// Encrypt encrypts plaintext with AES in CBC mode using secretKey. secretKey is
// passed to aes.NewCipher, so it must be a valid AES key length (16, 24 or 32
// bytes); otherwise that error is returned unchanged.
//
// The plaintext is PKCS-padded to a whole number of blocks, and a freshly
// generated random IV is placed in front of the encrypted data, so the result
// is IV || CBC(pad(plaintext)).
//
// NOTE(review): the output is not authenticated (no MAC), so tampering with
// stored ciphertext is not detected by Decrypt.
func Encrypt(plaintext []byte, secretKey []byte) ([]byte, error) {
	// Named "block" rather than "aes" so the crypto/aes package is not shadowed.
	block, err := aes.NewCipher(secretKey)
	if err != nil {
		return nil, err
	}
	blockSize := block.BlockSize()

	// Allocate the whole output once: IV first, encrypted blocks after it.
	padded := pkcsPad(plaintext, blockSize)
	encrypted := make([]byte, blockSize+len(padded))
	iv := encrypted[:blockSize]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(encrypted[blockSize:], padded)
	return encrypted, nil
}
// EncryptString encrypts the bytes of plaintext with secretKey (see Encrypt)
// and returns the result encoded as standard base64, so that it can be stored
// in the database as a "text" value.
func EncryptString(plaintext string, secretKey []byte) (string, error) {
	sealed, err := Encrypt([]byte(plaintext), secretKey)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// Decrypt reverses Encrypt. ciphertext must be an IV followed by one or more
// AES-CBC encrypted blocks produced with secretKey and PKCS padding. The
// unpadded plaintext is returned.
//
// ciphertext is not modified: decryption is done into a newly allocated
// buffer, so the same input can safely be decrypted more than once.
//
// An error is returned for an invalid key, an input that is not a whole number
// of blocks or is shorter than an IV plus one block, or invalid padding.
//
// NOTE(review): the ciphertext is not authenticated, and padding failures are
// reported as distinct errors; if these errors reach untrusted parties this
// could act as a CBC padding oracle.
func Decrypt(ciphertext []byte, secretKey []byte) ([]byte, error) {
	// Named "block" rather than "aes" so the crypto/aes package is not shadowed.
	block, err := aes.NewCipher(secretKey)
	if err != nil {
		return nil, err
	}
	blockSize := block.BlockSize()
	if len(ciphertext)%blockSize != 0 {
		return nil, fmt.Errorf("input length %d is not a multiple of blocksize %d", len(ciphertext), blockSize)
	}
	// Without this guard an empty input passes the check above and then panics
	// on ciphertext[:blockSize]; an IV alone carries no data either.
	if len(ciphertext) < 2*blockSize {
		return nil, fmt.Errorf("input length %d is too short: need an IV plus at least one block (%d bytes)", len(ciphertext), 2*blockSize)
	}
	iv, body := ciphertext[:blockSize], ciphertext[blockSize:]

	// Decrypt into a fresh buffer rather than in place, so the caller's
	// ciphertext slice is left intact.
	plaintext := make([]byte, len(body))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintext, body)
	return pkcsUnpad(plaintext, blockSize)
}
// DecryptString is the inverse of EncryptString (or of its in-database
// equivalent). It decodes ciphertext from standard base64, decrypts the result
// with secretKey (see Decrypt), and returns the plaintext as a string.
//
// An error is returned if ciphertext is not valid base64, if decryption fails,
// or if the decrypted bytes are not valid UTF-8.
func DecryptString(ciphertext string, secretKey []byte) (string, error) {
	var (
		raw, plain []byte
		err        error
	)
	if raw, err = base64.StdEncoding.DecodeString(ciphertext); err != nil {
		return "", err
	}
	if plain, err = Decrypt(raw, secretKey); err != nil {
		return "", err
	}
	if !utf8.Valid(plain) {
		return "", errors.New("decrypted string is not valid UTF-8 data")
	}
	return string(plain), nil
}