func testResultCode(m *diam.Message, want uint32) bool { rc, err := m.FindAVP("Result-Code") if err != nil { return false } if code, ok := rc.Data.(datatype.Unsigned32); ok { return uint32(code) == want } return false }
// successCEA sends a success answer indicating that the CER was successfuly // parsed and accepted by the server. func successCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CER) error { hostIP, _, err := net.SplitHostPort(c.LocalAddr().String()) if err != nil { return fmt.Errorf("failed to parse own ip %q: %s", c.LocalAddr(), err) } a := m.Answer(diam.Success) a.NewAVP(avp.OriginHost, avp.Mbit, 0, sm.cfg.OriginHost) a.NewAVP(avp.OriginRealm, avp.Mbit, 0, sm.cfg.OriginRealm) a.NewAVP(avp.HostIPAddress, avp.Mbit, 0, datatype.Address(net.ParseIP(hostIP))) a.NewAVP(avp.VendorID, avp.Mbit, 0, sm.cfg.VendorID) a.NewAVP(avp.ProductName, 0, 0, sm.cfg.ProductName) if cer.OriginStateID != nil { a.AddAVP(cer.OriginStateID) } if cer.AcctApplicationID != nil { for _, acct := range cer.AcctApplicationID { a.AddAVP(acct) } } if cer.AuthApplicationID != nil { for _, auth := range cer.AuthApplicationID { a.AddAVP(auth) } } if cer.VendorSpecificApplicationID != nil { for _, vs := range cer.VendorSpecificApplicationID { a.AddAVP(vs) } } if sm.cfg.FirmwareRevision != 0 { a.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, sm.cfg.FirmwareRevision) } _, err = a.WriteTo(c) return err }
// debitInterval is the configured debitInterval, in sync with the diameter client one func NewCCRFromDiameterMessage(m *diam.Message, debitInterval time.Duration) (*CCR, error) { var ccr CCR if err := m.Unmarshal(&ccr); err != nil { return nil, err } ccr.diamMessage = m ccr.debitInterval = debitInterval return &ccr, nil }
// Parse parses and validates the given message, and returns nil when // all AVPs are ok. func (dwr *DWR) Parse(m *diam.Message) error { err := m.Unmarshal(dwr) if err != nil { return nil } if err = dwr.sanityCheck(); err != nil { return err } return nil }
// Parse parses and validates the given message. func (cea *CEA) Parse(m *diam.Message) (err error) { if err = m.Unmarshal(cea); err != nil { return err } if err = cea.sanityCheck(); err != nil { return err } app := &Application{ AcctApplicationID: cea.AcctApplicationID, AuthApplicationID: cea.AuthApplicationID, VendorSpecificApplicationID: cea.VendorSpecificApplicationID, } if _, err := app.Parse(m.Dictionary()); err != nil { return err } cea.appID = app.ID() return nil }
// messageSetAVPsWithPath dynamically sets an AVP inside m at the location
// described by path, wrapping the value in Grouped AVPs for every path
// element that precedes the last one.
//
// Each element of path is resolved through the message dictionary using
// m.Header.ApplicationID; the last element must not be a Grouped AVP. The
// value avpValStr is serialized with serializeAVPValueFromString (which also
// receives timezone) and decoded into the data type of the last AVP.
//
// appnd: when true the newly built AVP is always appended to the message.
// When false:
//   - if the parent group (path without its last element) is already present
//     in m, the new value AVP is added to the first such group;
//   - otherwise, if an AVP matching the full path is present, the first match
//     is overwritten with the newly built AVP;
//   - otherwise the newly built AVP is appended to the message.
//
// m.Header.MessageLength is adjusted by the length of the AVP that was added
// or replaced.
//
// It returns an error for an empty path, a path element that cannot be
// resolved in the dictionary, a Grouped last element, or a value that fails
// to serialize or decode.
func messageSetAVPsWithPath(m *diam.Message, path []interface{}, avpValStr string, appnd bool, timezone string) error {
	if len(path) == 0 {
		return errors.New("Empty path as AVP filter")
	}
	dictAVPs := make([]*dict.AVP, len(path)) // for each subpath, one dictionary AVP
	for i, subpath := range path {
		if dictAVP, err := m.Dictionary().FindAVP(m.Header.ApplicationID, subpath); err != nil {
			return err
		} else if dictAVP == nil {
			// NOTE(review): the message reports the last path element, which is
			// not necessarily the element that failed to resolve — confirm intent.
			return fmt.Errorf("Cannot find AVP with id: %s", path[len(path)-1])
		} else {
			dictAVPs[i] = dictAVP
		}
	}
	if dictAVPs[len(path)-1].Data.Type == diam.GroupedAVPType {
		return errors.New("Last AVP in path needs not to be GroupedAVP")
	}
	var msgAVP *diam.AVP // Keeps a reference to the AVP built in the previous iteration
	lastAVPIdx := len(path) - 1
	// Build from the innermost (value) AVP outwards, so that each group can
	// embed the AVP created one step earlier.
	for i := lastAVPIdx; i >= 0; i-- {
		var typeVal datatype.Type
		if i == lastAVPIdx {
			// Innermost AVP: carries the actual value.
			avpValByte, err := serializeAVPValueFromString(dictAVPs[i], avpValStr, timezone)
			if err != nil {
				return err
			}
			typeVal, err = datatype.Decode(dictAVPs[i].Data.Type, avpValByte)
			if err != nil {
				return err
			}
		} else {
			// Enclosing group: wraps the AVP built in the previous iteration.
			typeVal = &diam.GroupedAVP{
				AVP: []*diam.AVP{msgAVP}}
		}
		newMsgAVP := diam.NewAVP(dictAVPs[i].Code, avp.Mbit, dictAVPs[i].VendorID, typeVal) // FixMe: maybe Mbit with dictionary one
		if i == lastAVPIdx-1 && !appnd { // last AVP needs to be appended in group
			// The lookup error is ignored; an empty result falls through to
			// building a new group.
			avps, _ := m.FindAVPsWithPath(path[:lastAVPIdx], dict.UndefinedVendorID)
			if len(avps) != 0 { // Group AVP already in the message
				// msgAVP still points at the value AVP here, since newMsgAVP has
				// not been assigned to it yet.
				// NOTE(review): unchecked type assertion — panics if the found AVP
				// is not a *diam.GroupedAVP.
				prevGrpData := avps[0].Data.(*diam.GroupedAVP)
				prevGrpData.AVP = append(prevGrpData.AVP, msgAVP)
				m.Header.MessageLength += uint32(msgAVP.Len())
				return nil
			}
		}
		msgAVP = newMsgAVP
	}
	if !appnd { // Not group AVP, replace the previous set one with this one
		// The lookup error is ignored; an empty result falls through to append.
		avps, _ := m.FindAVPsWithPath(path, dict.UndefinedVendorID)
		if len(avps) != 0 { // AVP already in the message
			m.Header.MessageLength -= uint32(avps[0].Len()) // decrease message length since we overwrite
			*avps[0] = *msgAVP
			m.Header.MessageLength += uint32(msgAVP.Len())
			return nil
		}
	}
	// Nothing to merge into or overwrite: add the new AVP at message level.
	m.AVP = append(m.AVP, msgAVP)
	m.Header.MessageLength += uint32(msgAVP.Len())
	return nil
}
// Parse parses and validates the given message, and returns nil when // all AVPs are ok, and all accounting or authentication applications // in the CER match the applications in our dictionary. If one or more // mandatory AVPs are missing, it returns a nil failedAVP and a proper // error. If all mandatory AVPs are present but no common application // is found, then it returns the failedAVP (with the application that // we don't support in our dictionary) and an error. Another cause // for error is the presence of Inband Security, we don't support that. func (cer *CER) Parse(m *diam.Message) (failedAVP *diam.AVP, err error) { if err = m.Unmarshal(cer); err != nil { return nil, err } if err = cer.sanityCheck(); err != nil { return nil, err } if cer.InbandSecurityID != nil { if v := cer.InbandSecurityID.Data.(datatype.Unsigned32); v != 0 { return cer.InbandSecurityID, ErrNoCommonSecurity } } app := &Application{ AcctApplicationID: cer.AcctApplicationID, AuthApplicationID: cer.AuthApplicationID, VendorSpecificApplicationID: cer.VendorSpecificApplicationID, } if failedAVP, err = app.Parse(m.Dictionary()); err != nil { return failedAVP, err } cer.appID = app.ID() return nil, nil }
// Handler for meta functions func metaHandler(m *diam.Message, tag, arg string, dur time.Duration) (string, error) { switch tag { case META_CCR_USAGE: var ok bool var reqType datatype.Enumerated var reqNr, reqUnit, usedUnit datatype.Unsigned32 if ccReqTypeAvp, err := m.FindAVP("CC-Request-Type", 0); err != nil { return "", err } else if ccReqTypeAvp == nil { return "", errors.New("CC-Request-Type not found") } else if reqType, ok = ccReqTypeAvp.Data.(datatype.Enumerated); !ok { return "", fmt.Errorf("CC-Request-Type must be Enumerated and not %v", ccReqTypeAvp.Data.Type()) } if ccReqNrAvp, err := m.FindAVP("CC-Request-Number", 0); err != nil { return "", err } else if ccReqNrAvp == nil { return "", errors.New("CC-Request-Number not found") } else if reqNr, ok = ccReqNrAvp.Data.(datatype.Unsigned32); !ok { return "", fmt.Errorf("CC-Request-Number must be Unsigned32 and not %v", ccReqNrAvp.Data.Type()) } switch reqType { case datatype.Enumerated(1), datatype.Enumerated(2): if reqUnitAVPs, err := m.FindAVPsWithPath([]interface{}{"Requested-Service-Unit", "CC-Time"}, dict.UndefinedVendorID); err != nil { return "", err } else if len(reqUnitAVPs) == 0 { return "", errors.New("Requested-Service-Unit>CC-Time not found") } else if reqUnit, ok = reqUnitAVPs[0].Data.(datatype.Unsigned32); !ok { return "", fmt.Errorf("Requested-Service-Unit>CC-Time must be Unsigned32 and not %v", reqUnitAVPs[0].Data.Type()) } case datatype.Enumerated(3), datatype.Enumerated(4): if usedUnitAVPs, err := m.FindAVPsWithPath([]interface{}{"Used-Service-Unit", "CC-Time"}, dict.UndefinedVendorID); err != nil { return "", err } else if len(usedUnitAVPs) != 0 { if usedUnit, ok = usedUnitAVPs[0].Data.(datatype.Unsigned32); !ok { return "", fmt.Errorf("Used-Service-Unit>CC-Time must be Unsigned32 and not %v", usedUnitAVPs[0].Data.Type()) } } } usage := usageFromCCR(int(reqType), int(reqNr), int(reqUnit), int(usedUnit), dur) return strconv.FormatFloat(usage.Seconds(), 'f', -1, 64), nil } return "", nil }
// AsBareDiameterMessage converts CCA into a bare DiameterMessage func (self *CCA) AsBareDiameterMessage() *diam.Message { var m diam.Message utils.Clone(self.diamMessage, &m) m.NewAVP(avp.SessionID, avp.Mbit, 0, datatype.UTF8String(self.SessionId)) m.NewAVP(avp.OriginHost, avp.Mbit, 0, datatype.DiameterIdentity(self.OriginHost)) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, datatype.DiameterIdentity(self.OriginRealm)) m.NewAVP(avp.AuthApplicationID, avp.Mbit, 0, datatype.Unsigned32(self.AuthApplicationId)) m.NewAVP(avp.CCRequestType, avp.Mbit, 0, datatype.Enumerated(self.CCRequestType)) m.NewAVP(avp.CCRequestNumber, avp.Mbit, 0, datatype.Enumerated(self.CCRequestNumber)) m.NewAVP(avp.ResultCode, avp.Mbit, 0, datatype.Unsigned32(self.ResultCode)) return &m }
// avpsWithPath is used to find AVPs by specifying RSRField as filter func avpsWithPath(m *diam.Message, rsrFld *utils.RSRField) ([]*diam.AVP, error) { return m.FindAVPsWithPath(splitIntoInterface(rsrFld.Id, utils.HIERARCHY_SEP), dict.UndefinedVendorID) }
func (self *DiameterClient) SendMessage(m *diam.Message) error { _, err := m.WriteTo(self.conn) return err }
func sendCEA(w io.Writer, m *diam.Message, OriginStateID, AcctApplicationID *diam.AVP) (n int64, err error) { m.NewAVP(avp.OriginHost, avp.Mbit, 0, datatype.OctetString("srv")) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, datatype.OctetString("localhost")) m.NewAVP(avp.HostIPAddress, avp.Mbit, 0, datatype.Address(net.ParseIP("127.0.0.1"))) m.NewAVP(avp.VendorID, avp.Mbit, 0, datatype.Unsigned32(99)) m.NewAVP(avp.ProductName, avp.Mbit, 0, datatype.UTF8String("go-diameter")) m.AddAVP(OriginStateID) m.AddAVP(AcctApplicationID) return m.WriteTo(w) }
// errorCEA sends an error answer indicating that the CER failed due to // an unsupported (acct/auth) application, and includes the AVP that // caused the failure in the message. func errorCEA(sm *StateMachine, c diam.Conn, m *diam.Message, cer *smparser.CER, failedAVP *diam.AVP) error { hostIP, _, err := net.SplitHostPort(c.LocalAddr().String()) if err != nil { return fmt.Errorf("failed to parse own ip %q: %s", c.LocalAddr(), err) } var a *diam.Message if failedAVP == cer.InbandSecurityID { a = m.Answer(diam.NoCommonSecurity) } else { a = m.Answer(diam.NoCommonApplication) } a.Header.CommandFlags |= diam.ErrorFlag a.NewAVP(avp.OriginHost, avp.Mbit, 0, sm.cfg.OriginHost) a.NewAVP(avp.OriginRealm, avp.Mbit, 0, sm.cfg.OriginRealm) a.NewAVP(avp.HostIPAddress, avp.Mbit, 0, datatype.Address(net.ParseIP(hostIP))) a.NewAVP(avp.VendorID, avp.Mbit, 0, sm.cfg.VendorID) a.NewAVP(avp.ProductName, 0, 0, sm.cfg.ProductName) if cer.OriginStateID != nil { a.AddAVP(cer.OriginStateID) } a.NewAVP(avp.FailedAVP, avp.Mbit, 0, &diam.GroupedAVP{ AVP: []*diam.AVP{failedAVP}, }) if sm.cfg.FirmwareRevision != 0 { a.NewAVP(avp.FirmwareRevision, avp.Mbit, 0, sm.cfg.FirmwareRevision) } _, err = a.WriteTo(c) return err }
func sendACR(c diam.Conn, cfg *sm.Settings, n int) { // Get this client's metadata from the connection object, // which is set by the state machine after the handshake. // It contains the peer's Origin-Host and Realm from the // CER/CEA handshake. We use it to populate the AVPs below. meta, ok := smpeer.FromContext(c.Context()) if !ok { log.Fatal("Client connection does not contain metadata") } var err error var m *diam.Message for i := 0; i < n; i++ { m = diam.NewRequest(diam.Accounting, 0, c.Dictionary()) m.NewAVP(avp.SessionID, avp.Mbit, 0, datatype.UTF8String(strconv.Itoa(i))) m.NewAVP(avp.OriginHost, avp.Mbit, 0, cfg.OriginHost) m.NewAVP(avp.OriginRealm, avp.Mbit, 0, cfg.OriginRealm) m.NewAVP(avp.DestinationRealm, avp.Mbit, 0, meta.OriginRealm) m.NewAVP(avp.AccountingRecordType, avp.Mbit, 0, eventRecord) m.NewAVP(avp.AccountingRecordNumber, avp.Mbit, 0, datatype.Unsigned32(i)) m.NewAVP(avp.DestinationHost, avp.Mbit, 0, meta.OriginHost) if _, err = m.WriteTo(c); err != nil { log.Fatal(err) } } }
// Parse parses the given message. func (dwa *DWA) Parse(m *diam.Message) error { if err := m.Unmarshal(dwa); err != nil { return err } return nil }