From b6cf9dda4d331a02d9564b0a3eee13b66195e6b1 Mon Sep 17 00:00:00 2001 From: DanB Date: Fri, 18 Dec 2015 19:02:54 +0100 Subject: [PATCH] Diameter messageAddAVPsWithPath --- agents/libdmt.go | 58 +++++++++++++++++++++++++++++++++++++------ agents/libdmt_test.go | 15 +++++++++++ 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/agents/libdmt.go b/agents/libdmt.go index 39766b378..778de367c 100644 --- a/agents/libdmt.go +++ b/agents/libdmt.go @@ -204,13 +204,19 @@ func metaHandler(m *diam.Message, tag, arg string, debitInterval time.Duration) return "", nil } -func avpsWithPath(m *diam.Message, rsrFld *utils.RSRField) ([]*diam.AVP, error) { - hierarchyPath := strings.Split(rsrFld.Id, utils.HIERARCHY_SEP) - hpIf := make([]interface{}, len(hierarchyPath)) - for i, val := range hierarchyPath { - hpIf[i] = val +// splitIntoInterface is used to split a string into []interface{} instead of []string +func splitIntoInterface(content, sep string) []interface{} { + spltStr := strings.Split(content, sep) + spltIf := make([]interface{}, len(spltStr)) + for i, val := range spltStr { + spltIf[i] = val } - return m.FindAVPsWithPath(hpIf, dict.UndefinedVendorID) + return spltIf +} + +// avpsWithPath is used to find AVPs by specifying RSRField as filter +func avpsWithPath(m *diam.Message, rsrFld *utils.RSRField) ([]*diam.AVP, error) { + return m.FindAVPsWithPath(splitIntoInterface(rsrFld.Id, utils.HIERARCHY_SEP), dict.UndefinedVendorID) } // Follows the implementation in the StorCdr @@ -297,6 +303,44 @@ func fieldOutVal(m *diam.Message, cfgFld *config.CfgCdrField, debitInterval time return fmtValOut, nil } +// messageAddAVPsWithPath will dynamically add AVPs into the message +func messageAddAVPsWithPath(m *diam.Message, path []interface{}, avpValByte []byte) error { + if len(path) == 0 { + return errors.New("Empty path as AVP filter") + } + dictAVPs := make([]*dict.AVP, len(path)) // for each subpath, one dictionary AVP + for i, subpath := range path { + if dictAVP, err := m.Dictionary().FindAVP(m.Header.ApplicationID, subpath); err != nil { + return err + } else if dictAVP == nil { + return fmt.Errorf("Cannot find AVP with id: %s", path[len(path)-1]) + } else { + dictAVPs[i] = dictAVP + } + } + if dictAVPs[len(path)-1].Data.Type == diam.GroupedAVPType { + return errors.New("Last AVP in path needs not to be GroupedAVP") + } + var msgAVP *diam.AVP // Keep a reference here towards last AVP + for i := len(path) - 1; i >= 0; i-- { + var typeVal datatype.Type + var err error + if msgAVP == nil { + typeVal, err = datatype.Decode(dictAVPs[i].Data.Type, avpValByte) + if err != nil { + return err + } + } else { + typeVal = &diam.GroupedAVP{ + AVP: []*diam.AVP{msgAVP}} + } + msgAVP = diam.NewAVP(dictAVPs[i].Code, avp.Mbit, dictAVPs[i].VendorID, typeVal) // FixMe: maybe Mbit with dictionary one + } + m.AVP = append(m.AVP, msgAVP) + m.Header.MessageLength += uint32(msgAVP.Len()) + return nil +} + // debitInterval is the configured debitInterval, in sync with the diameter client one func NewCCRFromDiameterMessage(m *diam.Message, debitInterval time.Duration) (*CCR, error) { var ccr CCR @@ -455,7 +499,7 @@ func NewCCAFromCCR(ccr *CCR) *CCA { } } -// Call Control Answer +// Call Control Answer, bare structure so we can dynamically manage adding it's fields type CCA struct { SessionId string `avp:"Session-Id"` OriginHost string `avp:"Origin-Host"` diff --git a/agents/libdmt_test.go b/agents/libdmt_test.go index 49f28b5ed..68fdb608e 100644 --- a/agents/libdmt_test.go +++ b/agents/libdmt_test.go @@ -19,6 +19,7 @@ along with this program. If not, see package agents import ( + "reflect" "testing" "time" @@ -135,3 +136,17 @@ func TestFieldOutVal(t *testing.T) { t.Errorf("Expecting: %s, received: %s", eOut, fldOut) } } + +func TestMessageAddAVPsWithPath(t *testing.T) { + eMessage := diam.NewRequest(diam.CreditControl, 4, nil) + eMessage.NewAVP("Subscription-Id", avp.Mbit, 0, &diam.GroupedAVP{ + AVP: []*diam.AVP{ + diam.NewAVP(444, avp.Mbit, 0, datatype.UTF8String("33708000003")), // Subscription-Id-Data + }}) + m := diam.NewMessage(diam.CreditControl, diam.RequestFlag, 4, eMessage.Header.HopByHopID, eMessage.Header.EndToEndID, nil) + if err := messageAddAVPsWithPath(m, []interface{}{"Subscription-Id", "Subscription-Id-Data"}, []byte("33708000003")); err != nil { + t.Error(err) + } else if !reflect.DeepEqual(eMessage, m) { + t.Errorf("Expecting: %+v, received: %+v", eMessage, m) + } +}