/* Real-time Online/Offline Charging System (OCS) for Telecom & ISP environments Copyright (C) ITsysCOM GmbH This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see */ package agents import ( "errors" "fmt" "net" "os" "path/filepath" "strconv" "strings" "time" "github.com/cgrates/cgrates/config" "github.com/cgrates/cgrates/engine" "github.com/cgrates/cgrates/utils" "github.com/cgrates/go-diameter/diam" "github.com/cgrates/go-diameter/diam/avp" "github.com/cgrates/go-diameter/diam/datatype" "github.com/cgrates/go-diameter/diam/dict" ) func loadDictionaries(dictsDir, componentID string) error { fi, err := os.Stat(dictsDir) if err != nil { if strings.HasSuffix(err.Error(), "no such file or directory") { return fmt.Errorf("<%s> Invalid dictionaries folder: <%s>", componentID, dictsDir) } return err } else if !fi.IsDir() { // If config dir defined, needs to exist return fmt.Errorf(" Path: <%s> is not a directory", dictsDir) } return filepath.Walk(dictsDir, func(path string, info os.FileInfo, err error) error { if !info.IsDir() { return nil } cfgFiles, err := filepath.Glob(filepath.Join(path, "*.xml")) // Only consider .xml files if err != nil { return err } if cfgFiles == nil { // No need of processing further since there are no dictionary files in the folder return nil } for _, filePath := range cfgFiles { utils.Logger.Info(fmt.Sprintf("<%s> Loading dictionary out of file %s", componentID, filePath)) if err := dict.Default.LoadFile(filePath); err != nil { return err } } return nil }) } // diamAVPValue will extract the go primary value out of diameter type value func diamAVPAsIface(dAVP *diam.AVP) (val any, err error) { if dAVP == nil { return nil, errors.New("nil AVP") } switch dAVP.Data.Type() { default: return nil, fmt.Errorf("unsupported AVP data type: %d", dAVP.Data.Type()) case datatype.AddressType: return net.IP([]byte(dAVP.Data.(datatype.Address))), nil case datatype.DiameterIdentityType: return string(dAVP.Data.(datatype.DiameterIdentity)), nil case datatype.DiameterURIType: return string(dAVP.Data.(datatype.DiameterURI)), nil case datatype.EnumeratedType: return int32(dAVP.Data.(datatype.Enumerated)), nil case datatype.Float32Type: return float32(dAVP.Data.(datatype.Float32)), nil case datatype.Float64Type: return float64(dAVP.Data.(datatype.Float64)), nil case datatype.IPFilterRuleType: return string(dAVP.Data.(datatype.IPFilterRule)), nil case datatype.IPv4Type: return net.IP([]byte(dAVP.Data.(datatype.IPv4))), nil case datatype.Integer32Type: return int32(dAVP.Data.(datatype.Integer32)), nil case datatype.Integer64Type: return int64(dAVP.Data.(datatype.Integer64)), nil case datatype.OctetStringType: return string(dAVP.Data.(datatype.OctetString)), nil case datatype.QoSFilterRuleType: return string(dAVP.Data.(datatype.QoSFilterRule)), nil case datatype.TimeType: return time.Time(dAVP.Data.(datatype.Time)), nil case datatype.UTF8StringType: return string(dAVP.Data.(datatype.UTF8String)), nil case datatype.Unsigned32Type: return uint32(dAVP.Data.(datatype.Unsigned32)), nil case datatype.Unsigned64Type: return uint64(dAVP.Data.(datatype.Unsigned64)), nil } } func diamAVPAsString(dAVP *diam.AVP) (s string, err error) { var iface any if iface, err = diamAVPAsIface(dAVP); err != nil { return } return utils.IfaceAsString(iface), nil } // newDiamDataType constructs dataType from valStr func newDiamDataType(typ datatype.TypeID, valStr, tmz string) (dt datatype.Type, err error) { switch typ { default: return nil, fmt.Errorf("unsupported AVP data type: %d", typ) case datatype.AddressType: return datatype.Address(net.ParseIP(valStr)), nil case datatype.DiameterIdentityType: return datatype.DiameterIdentity(valStr), nil case datatype.DiameterURIType: return datatype.DiameterURI(valStr), nil case datatype.EnumeratedType: i, err := strconv.ParseInt(valStr, 10, 32) if err != nil { return nil, err } return datatype.Enumerated(int32(i)), nil case datatype.Float32Type: f, err := strconv.ParseFloat(valStr, 32) if err != nil { return nil, err } return datatype.Float32(float32(f)), nil case datatype.Float64Type: f, err := strconv.ParseFloat(valStr, 64) if err != nil { return nil, err } return datatype.Float64(f), nil case datatype.IPFilterRuleType: return datatype.IPFilterRule(valStr), nil case datatype.IPv4Type: return datatype.IPv4(net.ParseIP(valStr)), nil case datatype.Integer32Type: i, err := strconv.ParseInt(valStr, 10, 32) if err != nil { return nil, err } return datatype.Integer32(int32(i)), nil case datatype.Integer64Type: i, err := strconv.ParseInt(valStr, 10, 64) if err != nil { return nil, err } return datatype.Integer64(i), nil case datatype.OctetStringType: return datatype.OctetString(valStr), nil case datatype.QoSFilterRuleType: return datatype.QoSFilterRule(valStr), nil case datatype.TimeType: t, err := utils.ParseTimeDetectLayout(valStr, tmz) if err != nil { return nil, err } return datatype.Time(t), nil case datatype.UTF8StringType: return datatype.UTF8String(valStr), nil case datatype.Unsigned32Type: i, err := strconv.ParseUint(valStr, 10, 32) if err != nil { return nil, err } return datatype.Unsigned32(uint32(i)), nil case datatype.Unsigned64Type: i, err := strconv.ParseUint(valStr, 10, 64) if err != nil { return nil, err } return datatype.Unsigned64(i), nil } } func headerLen(a *diam.AVP) int { if a.Flags&avp.Vbit == avp.Vbit { return 12 } return 8 } func updateAVPLength(avps []*diam.AVP) (l int) { for _, avp := range avps { if v, ok := (avp.Data).(*diam.GroupedAVP); ok { avp.Length = headerLen(avp) + updateAVPLength(v.AVP) } l += avp.Length } return } // messageAddAVPsWithPath will dynamically add AVPs into the message // // append: append to the message, on false overwrite if AVP is single or add to group if AVP is Grouped func messageSetAVPsWithPath(m *diam.Message, pathStr []string, avpValStr string, newBranch bool, tmz string) (err error) { if len(pathStr) == 0 { return errors.New("empty path as AVP filter") } path := utils.SliceStringToIface(pathStr) dictAVPs := make([]*dict.AVP, len(path)) // for each subpath, one dictionary AVP for i, subpath := range path { if dictAVP, err := m.Dictionary().FindAVP(m.Header.ApplicationID, subpath); err != nil { return err } else if dictAVP == nil { return fmt.Errorf("cannot find AVP with id: %s", path[len(path)-1]) } else { dictAVPs[i] = dictAVP } } lastAVPIdx := len(path) - 1 if dictAVPs[lastAVPIdx].Data.Type == diam.GroupedAVPType { return errors.New("last AVP in path cannot be GroupedAVP") } var msgAVP *diam.AVP // Keep a reference here towards last AVP for i := lastAVPIdx; i >= 0; i-- { var typeVal datatype.Type if i == lastAVPIdx { if typeVal, err = newDiamDataType(dictAVPs[i].Data.Type, avpValStr, tmz); err != nil { return err } } else { typeVal = &diam.GroupedAVP{ AVP: []*diam.AVP{msgAVP}} } msgAVP = diam.NewAVP(dictAVPs[i].Code, avp.Mbit, dictAVPs[i].VendorID, typeVal) // FixMe: maybe Mbit with dictionary one if i > 0 && !newBranch { avps, err := m.FindAVPsWithPath(path[:i], dict.UndefinedVendorID) if err != nil { return err } if len(avps) != 0 { // Group AVP already in the message prevGrpData, ok := avps[len(avps)-1].Data.(*diam.GroupedAVP) // Take the last avp found to append there if ok { prevGrpData.AVP = append(prevGrpData.AVP, msgAVP) m.Header.MessageLength += uint32(msgAVP.Len()) // updateAVPLenght(m.AVP) return nil } } } } if !newBranch { // Not group AVP, replace the previous set one with this one avps, err := m.FindAVPsWithPath(path, dict.UndefinedVendorID) if err != nil { return err } if len(avps) != 0 { // Group AVP already in the message m.Header.MessageLength -= uint32(avps[len(avps)-1].Len()) // decrease message length since we overwrite *avps[len(avps)-1] = *msgAVP m.Header.MessageLength += uint32(msgAVP.Len()) // updateAVPLenght(m.AVP) return nil } } m.AVP = append(m.AVP, msgAVP) m.Header.MessageLength += uint32(msgAVP.Len()) // updateAVPLenght(m.AVP) return nil } // writeOnConn writes the message on connection, logs failures func writeOnConn(c diam.Conn, m *diam.Message) (err error) { if _, err = m.WriteTo(c); err != nil { utils.Logger.Warning(fmt.Sprintf("<%s> failed writing message to %s, err: %s, msg: %s", utils.DiameterAgent, c.RemoteAddr(), err.Error(), m)) } return } // newDADataProvider constructs a DataProvider for a diameter message func newDADataProvider(c diam.Conn, m *diam.Message) utils.DataProvider { return &diameterDP{c: c, m: m, cache: utils.MapStorage{}} } // diameterDP implements utils.DataProvider, serving as diam.Message data decoder // decoded data is only searched once and cached type diameterDP struct { c diam.Conn m *diam.Message cache utils.MapStorage } // String is part of utils.DataProvider interface // when called, it will display the already parsed values out of cache func (dP *diameterDP) String() string { return dP.m.String() } // FieldAsString is part of utils.DataProvider interface func (dP *diameterDP) FieldAsString(fldPath []string) (data string, err error) { var valIface any valIface, err = dP.FieldAsInterface(fldPath) if err != nil { return } return utils.IfaceAsString(valIface), nil } // FieldAsInterface is part of utils.DataProvider interface func (dP *diameterDP) FieldAsInterface(fldPath []string) (data any, err error) { if data, err = dP.cache.FieldAsInterface(fldPath); err != nil { if err != utils.ErrNotFound { // item found in cache return nil, err } err = nil // cancel previous err } else { return // data was found in cache } // lastPath can contain selector inside lastPath := fldPath[len(fldPath)-1] var slctrStr string if splt := strings.Split(lastPath, utils.IdxStart); len(splt) != 1 { lastPath = splt[0] if splt[1][len(splt[1])-1:] != utils.IdxEnd { return nil, fmt.Errorf("filter rule <%s> needs to end in ]", splt[1]) } slctrStr = splt[1][:len(splt[1])-1] // also strip the last ] } pathIface := utils.SliceStringToIface(fldPath) if slctrStr != "" { // last path was having selector inside before pathIface[len(pathIface)-1] = lastPath // last path was changed } var avps []*diam.AVP if avps, err = dP.m.FindAVPsWithPath( pathIface, dict.UndefinedVendorID); err != nil { return nil, err } else if len(avps) == 0 { return nil, utils.ErrNotFound } slectedIdx := 0 // by default we select AVP[0] if slctrStr != "" { if slectedIdx, err = strconv.Atoi(slctrStr); err != nil { // not int, compile it as RSRParser selIndxs := make(map[int]int) // use it to find intersection of all matched filters slctrStrs := strings.Split(slctrStr, utils.PipeSep) for _, slctrStr := range slctrStrs { var fltrs utils.RSRFilters if strings.HasSuffix(slctrStr, utils.FilterValEnd) { // Has filter, populate the var fltrStart := strings.Index(slctrStr, utils.FilterValStart) if fltrStart < 1 { return nil, fmt.Errorf("invalid RSRFilter start rule in string: <%s>", slctrStr) } fltrVal := slctrStr[fltrStart+1 : len(slctrStr)-1] if fltrs, err = utils.ParseRSRFilters(fltrVal, utils.ANDSep); err != nil { return nil, fmt.Errorf("Invalid FilterValue in string: %s, err: %s", fltrVal, err.Error()) } slctrStr = slctrStr[:fltrStart] // Take the filter part out before compiling further } slctr, err := config.NewRSRParser(slctrStr) if err != nil { return nil, err } var fltrAVPs []*diam.AVP for i := len(pathIface) - 1; i > 0; i-- { pathIface[i] = slctr.AttrName() // search for AVPs which are having common path but different end element pathIface = pathIface[:i+1] if fltrAVPs, err = dP.m.FindAVPsWithPath(pathIface, dict.UndefinedVendorID); err != nil || len(fltrAVPs) != 0 { break // found AVPs or got error, go and compare } } if err != nil { return nil, err } else if len(fltrAVPs) == 0 || len(fltrAVPs) != len(avps) { return nil, utils.ErrNotFound } for k, fAVP := range fltrAVPs { if dataAVP, err := diamAVPAsIface(fAVP); err != nil { return nil, err } else if fld, err := slctr.ParseValue(dataAVP); err != nil { if err == utils.ErrNotFound && fltrs.FilterRules() == "^$" { selIndxs[k+1]++ // filter passing, index it with one higher to cover 0 continue // filter not passing, not really error } return nil, err } else if fltrs.Pass(fld, true) { selIndxs[k+1]++ // filter passing, index it with one higher to cover 0 } } } var oneMatches bool for k, matches := range selIndxs { if matches == len(slctrStrs) { // all filters in selection matching oneMatches = true slectedIdx = k - 1 // decrease it to reflect real index break } } if !oneMatches { return nil, utils.ErrNotFound // return NotFound and let the other subsystem handle it (e.g. FilterS ) } } } if slectedIdx >= len(avps) { return nil, errors.New("avp index higher than number of AVPs") } if data, err = diamAVPAsIface(avps[slectedIdx]); err != nil { return nil, err } dP.cache.Set(fldPath, data) return } // updateDiamMsgFromNavMap will update the diameter message with items from navigable map func updateDiamMsgFromNavMap(m *diam.Message, navMp *utils.OrderedNavigableMap, tmz string) (err error) { // write reply into message for el := navMp.GetFirstElement(); el != nil; el = el.Next() { path := el.Value nmIt, _ := navMp.Field(path) if nmIt == nil { continue // all attributes, not writable to diameter packet } path = path[:len(path)-1] // remove the last index if err = messageSetAVPsWithPath(m, path, nmIt.String(), nmIt.NewBranch, tmz); err != nil { return fmt.Errorf("setting item with path: %+v got err: %s", path, err.Error()) } } return } // diamAnswer builds up the answer to be sent back to the client func diamAnswer(m *diam.Message, resCode uint32, errFlag bool, rply *utils.OrderedNavigableMap, tmz string) (a *diam.Message, err error) { a = m.Answer(resCode) if errFlag { a.Header.CommandFlags = diam.ErrorFlag } if err = updateDiamMsgFromNavMap(a, rply, tmz); err != nil { return nil, err } return } // diamErr handles Diameter error scenarios by attempting to build a customized error answer // based on the *err template. func diamErr(c diam.Conn, m *diam.Message, resCode uint32, reqVars *utils.DataNode, cfg *config.CGRConfig, filterS *engine.FilterS) { tnt := cfg.GeneralCfg().DefaultTenant tmz := cfg.GeneralCfg().DefaultTimezone aReq := NewAgentRequest(newDADataProvider(nil, m), reqVars, nil, nil, nil, nil, tnt, tmz, filterS, nil) if err := aReq.SetFields(cfg.TemplatesCfg()[utils.MetaErr]); err != nil { utils.Logger.Warning(fmt.Sprintf( "<%s> message: %s - failed to parse *err template: %v", utils.DiameterAgent, m, err)) writeOnConn(c, diamErrMsg(m, diam.UnableToComply, fmt.Sprintf("failed to parse *err template: %v", err))) return } diamAns, err := diamAnswer(m, resCode, true, aReq.Reply, tmz) if err != nil { utils.Logger.Warning(fmt.Sprintf( "<%s> message: %s - failed to build error answer: %v", utils.DiameterAgent, m, err)) writeOnConn(c, diamErrMsg(m, diam.UnableToComply, fmt.Sprintf("failed to build error answer: %v", err))) return } writeOnConn(c, diamAns) } // diamErrMsg creates a Diameter error answer with the given result code and optional error message. func diamErrMsg(m *diam.Message, resCode uint32, msg string) *diam.Message { ans := m.Answer(resCode) ans.Header.CommandFlags = diam.ErrorFlag if msg != "" { ans.NewAVP(avp.ErrorMessage, 0, 0, datatype.UTF8String(msg)) } return ans } func disectDiamListen(addrs string) (ipAddrs []net.IP) { ipPort := strings.Split(addrs, utils.InInFieldSep) if ipPort[0] == "" { return } ips := strings.Split(ipPort[0], utils.HDRValSep) ipAddrs = make([]net.IP, len(ips)) for i, ip := range ips { ipAddrs[i] = net.ParseIP(ip) } return } // diamMessageData is cached when data is needed (ie. ) type diamMsgData struct { c diam.Conn m *diam.Message vars *utils.DataNode }