forked from fiorix/go-diameter
/
cer.go
127 lines (121 loc) · 3.75 KB
/
cer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright 2013-2015 go-diameter authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sm
import (
"fmt"
"net"
"github.com/fiorix/go-diameter/diam"
"github.com/fiorix/go-diameter/diam/avp"
"github.com/fiorix/go-diameter/diam/datatype"
"github.com/fiorix/go-diameter/diam/sm/smparser"
"github.com/fiorix/go-diameter/diam/sm/smpeer"
)
// handleCER handles Capabilities-Exchange-Request messages.
//
// A CER arriving on a connection whose context already carries peer
// metadata is treated as a retransmission and ignored.
//
// If the CER fails to parse or validate (for example because mandatory
// AVPs such as Origin-Host, Origin-Realm, or Origin-State-Id are missing),
// the connection is closed. When the parser also returns the offending AVP,
// an error CEA carrying it is sent before closing. The parse failure, and
// any failure to send that error CEA, is reported via sm.Error.
//
// On success a CEA is sent, peer metadata derived from the CER is stored in
// the connection context, and the connection is offered to sm.hsNotifyc
// without blocking.
//
// See RFC 6733 section 5.3 for details.
func handleCER(sm *StateMachine) diam.HandlerFunc {
	return func(c diam.Conn, m *diam.Message) {
		ctx := c.Context()
		if _, ok := smpeer.FromContext(ctx); ok {
			// Ignore retransmission.
			return
		}
		cer := new(smparser.CER)
		failedAVP, err := cer.Parse(m)
		if err != nil {
			var ceaErr error
			if failedAVP != nil {
				ceaErr = errorCEA(sm, c, m, cer, failedAVP)
			}
			// Close before reporting so the connection is released
			// regardless of how sm.Error delivers reports (its blocking
			// behavior is not visible here).
			c.Close()
			// Previously the parse error was discarded, leaving no trace of
			// why the handshake was rejected.
			sm.Error(&diam.ErrorReport{
				Conn:    c,
				Message: m,
				Error:   err,
			})
			if ceaErr != nil {
				sm.Error(&diam.ErrorReport{
					Conn:    c,
					Message: m,
					Error:   ceaErr,
				})
			}
			return
		}
		err = successCEA(sm, c, m, cer)
		if err != nil {
			sm.Error(&diam.ErrorReport{
				Conn:    c,
				Message: m,
				Error:   err,
			})
			return
		}
		meta := smpeer.FromCER(cer)
		c.SetContext(smpeer.NewContext(ctx, meta))
		// Notify about peer passing the handshake; drop the notification
		// if nobody is ready to receive it.
		select {
		case sm.hsNotifyc <- c:
		default:
		}
	}
}
// errorCEA sends a Capabilities-Exchange-Answer rejecting the CER m.
//
// The Result-Code is DIAMETER_NO_COMMON_SECURITY when failedAVP is the
// CER's Inband-Security-Id AVP, and DIAMETER_NO_COMMON_APPLICATION
// otherwise (e.g. an unsupported acct/auth application). The answer has the
// E bit set, carries our identity AVPs, and includes failedAVP wrapped in a
// Failed-AVP.
//
// cer comes from a CER whose parsing failed, so it may be incomplete; the
// peer's Origin-State-Id is therefore only copied when present.
//
// It returns an error if the local address cannot be split into host and
// port, or if writing the answer to c fails.
func errorCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CER, failedAVP *diam.AVP) error {
	hostIP, _, err := net.SplitHostPort(c.LocalAddr().String())
	if err != nil {
		return fmt.Errorf("failed to parse own ip %q: %s", c.LocalAddr(), err)
	}
	var a *diam.Message
	if failedAVP == cer.InbandSecurityID {
		a = m.Answer(diam.NoCommonSecurity)
	} else {
		a = m.Answer(diam.NoCommonApplication)
	}
	// NOTE(review): RFC 6733 reserves the E bit for protocol errors (3xxx),
	// while these Result-Codes are permanent failures (5xxx). Kept for
	// compatibility with existing behavior — confirm whether it is intended.
	a.Header.CommandFlags |= diam.ErrorFlag
	a.NewAVP(avp.OriginHost, avp.Mbit, 0, sm.cfg.OriginHost)
	a.NewAVP(avp.OriginRealm, avp.Mbit, 0, sm.cfg.OriginRealm)
	a.NewAVP(avp.HostIPAddress, avp.Mbit, 0, datatype.Address(net.ParseIP(hostIP)))
	a.NewAVP(avp.VendorID, avp.Mbit, 0, sm.cfg.VendorID)
	a.NewAVP(avp.ProductName, 0, 0, sm.cfg.ProductName)
	if cer.OriginStateID != nil {
		a.AddAVP(cer.OriginStateID)
	}
	a.NewAVP(avp.FailedAVP, avp.Mbit, 0, &diam.GroupedAVP{
		AVP: []*diam.AVP{failedAVP},
	})
	// RFC 6733 section 4.5: Firmware-Revision, like Product-Name, MUST NOT
	// have the M bit set.
	a.NewAVP(avp.FirmwareRevision, 0, 0, sm.cfg.FirmwareRevision)
	_, err = a.WriteTo(c)
	return err
}
// successCEA sends a DIAMETER_SUCCESS Capabilities-Exchange-Answer for the
// CER m, which was successfully parsed and accepted by the server.
//
// The answer carries our identity AVPs (Origin-Host, Origin-Realm,
// Host-IP-Address, Vendor-Id, Product-Name, Firmware-Revision), the CER's
// Origin-State-Id AVP when present, and echoes back every Acct-, Auth- and
// Vendor-Specific-Application-Id AVP found in the CER.
//
// It returns an error if the local address cannot be split into host and
// port, or if writing the answer to c fails.
func successCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CER) error {
	hostIP, _, err := net.SplitHostPort(c.LocalAddr().String())
	if err != nil {
		return fmt.Errorf("failed to parse own ip %q: %s", c.LocalAddr(), err)
	}
	a := m.Answer(diam.Success)
	a.NewAVP(avp.OriginHost, avp.Mbit, 0, sm.cfg.OriginHost)
	a.NewAVP(avp.OriginRealm, avp.Mbit, 0, sm.cfg.OriginRealm)
	a.NewAVP(avp.HostIPAddress, avp.Mbit, 0, datatype.Address(net.ParseIP(hostIP)))
	a.NewAVP(avp.VendorID, avp.Mbit, 0, sm.cfg.VendorID)
	a.NewAVP(avp.ProductName, 0, 0, sm.cfg.ProductName)
	// NOTE(review): this echoes the peer's Origin-State-Id and application
	// IDs; RFC 6733 CEA semantics suggest advertising our own — confirm.
	if cer.OriginStateID != nil {
		a.AddAVP(cer.OriginStateID)
	}
	// Ranging over a nil slice is a no-op, so no nil checks are needed.
	for _, acct := range cer.AcctApplicationID {
		a.AddAVP(acct)
	}
	for _, auth := range cer.AuthApplicationID {
		a.AddAVP(auth)
	}
	for _, vs := range cer.VendorSpecificApplicationID {
		a.AddAVP(vs)
	}
	// RFC 6733 section 4.5: Firmware-Revision, like Product-Name, MUST NOT
	// have the M bit set.
	a.NewAVP(avp.FirmwareRevision, 0, 0, sm.cfg.FirmwareRevision)
	_, err = a.WriteTo(c)
	return err
}