/
encode.go
129 lines (106 loc) · 2.73 KB
/
encode.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package saml
import (
"bytes"
"compress/flate"
"encoding/base64"
"io"
"sync"
"github.com/lestrrat/go-libxml2/parser"
"github.com/lestrrat/go-pdebug"
"github.com/lestrrat/go-xmlsec/crypto"
"github.com/lestrrat/go-xmlsec/dsig"
)
var (
	// b64enc is the base64 encoding applied to compressed output.
	b64enc = base64.StdEncoding

	// emptyBuffer is a throwaway destination that idle pooled writers point
	// at, so that a writer sitting in the pool does not keep a reference to
	// the buffer of whoever used it last. Keeping a buffer around only for
	// this purpose is wasteful, but simple.
	emptyBuffer = new(bytes.Buffer)

	// flateWriterPool recycles *flate.Writer values between calls so that a
	// fresh compressor does not have to be allocated for every message.
	flateWriterPool = sync.Pool{New: allocFlateWriter}
)

// releaseFlateWriter detaches w from the io.Writer it was last pointed at
// and hands it back to the pool. The caller must not use w afterwards.
func releaseFlateWriter(w *flate.Writer) {
	// Re-target first: this drops the reference to the caller's io.Writer.
	w.Reset(emptyBuffer)
	flateWriterPool.Put(w)
}

// getFlateWriter takes a *flate.Writer out of the pool. The writer targets
// the shared placeholder buffer, so the caller is expected to call Reset
// with its own destination before writing.
func getFlateWriter() *flate.Writer {
	fw := flateWriterPool.Get().(*flate.Writer)
	return fw
}

// allocFlateWriter is the pool's constructor: it builds a new *flate.Writer
// at the default compression level, initially targeting emptyBuffer.
func allocFlateWriter() interface{} {
	// The error is deliberately discarded: flate.NewWriter (as of this
	// writing) only fails when given an invalid compression level, and
	// flate.DefaultCompression is always valid.
	fw, _ := flate.NewWriter(emptyBuffer, flate.DefaultCompression)
	return fw
}
// serializer is implemented by any value that can render itself as a
// string of XML, which encode then optionally signs and compresses.
type serializer interface {
	// Serialize returns the textual form of the value, or an error if it
	// cannot be produced.
	Serialize() (xml string, err error)
}
// encode turns s into its wire representation.
//
// The value is first serialized to XML. If key is non-nil, the XML is parsed
// and an enveloped signature is attached to the document element and signed
// with key. If compress is false, the (possibly signed) XML is returned as
// raw bytes with no further encoding. If compress is true, the XML is
// DEFLATE-compressed and the compressed bytes are base64-encoded with b64enc.
//
// Any error from serialization, parsing, signing or compression is returned
// as-is with a nil byte slice.
func encode(s serializer, key *crypto.Key, compress bool) ([]byte, error) {
	xmlstr, err := s.Serialize()
	if err != nil {
		return nil, err
	}
	if pdebug.Enabled {
		pdebug.Printf("Generated %d bytes of XML", len(xmlstr))
	}

	if key != nil {
		xmlstr, err = signSerialized(xmlstr, key)
		if err != nil {
			return nil, err
		}
	}

	if !compress {
		return []byte(xmlstr), nil
	}
	return deflateAndEncode(xmlstr)
}

// signSerialized parses xmlstr, adds an enveloped signature to its document
// element, signs it with key, and returns the re-serialized document.
func signSerialized(xmlstr string, key *crypto.Key) (string, error) {
	p := parser.New(parser.XMLParseDTDLoad | parser.XMLParseDTDAttr | parser.XMLParseNoEnt)
	doc, err := p.ParseString(xmlstr)
	if err != nil {
		return "", err
	}
	// The parsed document is backed by libxml2 memory that the Go garbage
	// collector does not manage; release it once the signed XML has been
	// dumped back into a Go string.
	defer doc.Free()

	root, err := doc.DocumentElement()
	if err != nil {
		return "", err
	}

	// Create a new signature section. The error must be checked before sig
	// is touched: on failure there is no usable signature to add nodes to.
	sig, err := dsig.NewSignature(root, dsig.ExclC14N, dsig.RsaSha1, "")
	if err != nil {
		return "", err
	}
	if err := sig.AddReference(dsig.Sha1, "", "", ""); err != nil {
		return "", err
	}
	if err := sig.AddTransform(dsig.Enveloped); err != nil {
		return "", err
	}

	// A nil return from a Has* check is treated as "the key carries this
	// kind of material". Any raw key type gets a KeyValue node.
	if key.HasRsaKey() == nil || key.HasDsaKey() == nil || key.HasEcdsaKey() == nil {
		if err := sig.AddKeyValue(); err != nil {
			return "", err
		}
	}

	// If the key is set up using X509, add that node as well.
	if key.HasX509() == nil {
		if err := sig.AddX509Data(); err != nil {
			return "", err
		}
	}

	if pdebug.Enabled {
		pdebug.Printf("Signing using key %p", key)
	}
	if err := sig.Sign(key); err != nil {
		return "", err
	}

	return doc.Dump(false), nil
}

// deflateAndEncode DEFLATE-compresses xmlstr using a pooled writer and
// returns the compressed bytes encoded with b64enc.
func deflateAndEncode(xmlstr string) ([]byte, error) {
	var buf bytes.Buffer

	w := getFlateWriter()
	defer releaseFlateWriter(w)
	w.Reset(&buf)

	if _, err := io.WriteString(w, xmlstr); err != nil {
		return nil, err
	}
	// Close flushes the remaining compressed data into buf; it must happen
	// before buf is read.
	if err := w.Close(); err != nil {
		return nil, err
	}
	if pdebug.Enabled {
		pdebug.Printf("Compressed to %d bytes", buf.Len())
	}

	ret := make([]byte, b64enc.EncodedLen(buf.Len()))
	b64enc.Encode(ret, buf.Bytes())
	if pdebug.Enabled {
		pdebug.Printf("Encoded into %d bytes of base64", len(ret))
	}
	return ret, nil
}