From 315eddb63fb11973e30596d52050668fc68eecc2 Mon Sep 17 00:00:00 2001 From: Radu Ioan Fericean Date: Wed, 19 Mar 2014 17:32:02 +0200 Subject: [PATCH] moved locking before debit --- engine/account.go | 140 ++++++++++++++++++++++++---------------- engine/calldesc.go | 106 ++++++++++++++++++------------ engine/calldesc_test.go | 2 +- 3 files changed, 149 insertions(+), 99 deletions(-) diff --git a/engine/account.go b/engine/account.go index dfe1a6a8b..58f18fc2e 100644 --- a/engine/account.go +++ b/engine/account.go @@ -269,37 +269,36 @@ func (ub *Account) debitMinutesFromSharedBalances(myBalance *Balance, cc *CallCo return } sharingMembers := sharedGroup.Members - AccLock.GuardMany(sharedGroup.GetMembersExceptUser(ub.Id), func() (float64, error) { - var allMinuteSharedBalances BalanceChain - for _, ubId := range sharingMembers { - var nUb *Account - if ubId == ub.Id { // skip the initiating user - nUb = ub - } else { - nUb, err = accountingStorage.GetAccount(ubId) - if err != nil { - Logger.Warning(fmt.Sprintf("Could not get user balance: %s", ubId)) - } - if nUb.Disabled { - Logger.Warning(fmt.Sprintf("Disabled user in shared group: %s (%s)", ubId, sharedGroupName)) - continue - } + + var allMinuteSharedBalances BalanceChain + for _, ubId := range sharingMembers { + var nUb *Account + if ubId == ub.Id { // skip the initiating user + nUb = ub + } else { + nUb, err = accountingStorage.GetAccount(ubId) + if err != nil { + Logger.Warning(fmt.Sprintf("Could not get user balance: %s", ubId)) } - sharedMinuteBalances := nUb.getBalancesForPrefix(cc.Destination, nUb.BalanceMap[MINUTES+cc.Direction], sharedGroupName) - allMinuteSharedBalances = append(allMinuteSharedBalances, sharedMinuteBalances...) - } - for _, sharedBalance := range sharedGroup.GetBalancesByStrategy(myBalance, allMinuteSharedBalances) { - initialValue := sharedBalance.Value - sharedBalance.DebitMinutes(cc, count, sharedBalance.account, moneyBalances) - if sharedBalance.Value != initialValue { - accountingStorage.SetAccount(sharedBalance.account) - } - if cc.IsPaid() { - return 0, nil + if nUb.Disabled { + Logger.Warning(fmt.Sprintf("Disabled user in shared group: %s (%s)", ubId, sharedGroupName)) + continue } } - return 0, nil - }) + sharedMinuteBalances := nUb.getBalancesForPrefix(cc.Destination, nUb.BalanceMap[MINUTES+cc.Direction], sharedGroupName) + allMinuteSharedBalances = append(allMinuteSharedBalances, sharedMinuteBalances...) + } + for _, sharedBalance := range sharedGroup.GetBalancesByStrategy(myBalance, allMinuteSharedBalances) { + initialValue := sharedBalance.Value + sharedBalance.DebitMinutes(cc, count, sharedBalance.account, moneyBalances) + if sharedBalance.Value != initialValue { + accountingStorage.SetAccount(sharedBalance.account) + } + if cc.IsPaid() { + return + } + } + return } func (ub *Account) debitMoneyFromSharedBalances(myBalance *Balance, cc *CallCost, count bool) { @@ -310,37 +309,35 @@ func (ub *Account) debitMoneyFromSharedBalances(myBalance *Balance, cc *CallCost return } sharingMembers := sharedGroup.Members - AccLock.GuardMany(sharedGroup.GetMembersExceptUser(ub.Id), func() (float64, error) { - var allMoneySharedBalances BalanceChain - for _, ubId := range sharingMembers { - var nUb *Account - if ubId == ub.Id { // skip the initiating user - nUb = ub - } else { - nUb, err = accountingStorage.GetAccount(ubId) - if err != nil { - Logger.Warning(fmt.Sprintf("Could not get user balance: %s", ubId)) - } - if nUb.Disabled { - Logger.Warning(fmt.Sprintf("Disabled user in shared group: %s (%s)", ubId, sharedGroupName)) - continue - } + var allMoneySharedBalances BalanceChain + for _, ubId := range sharingMembers { + var nUb *Account + if ubId == ub.Id { // skip the initiating user + nUb = ub + } else { + nUb, err = accountingStorage.GetAccount(ubId) + if err != nil { + Logger.Warning(fmt.Sprintf("Could not get user balance: %s", ubId)) } - sharedMoneyBalances := nUb.getBalancesForPrefix(cc.Destination, nUb.BalanceMap[CREDIT+cc.Direction], sharedGroupName) - allMoneySharedBalances = append(allMoneySharedBalances, sharedMoneyBalances...) - } - for _, sharedBalance := range sharedGroup.GetBalancesByStrategy(myBalance, allMoneySharedBalances) { - initialValue := sharedBalance.Value - sharedBalance.DebitMoney(cc, count, sharedBalance.account) - if sharedBalance.Value != initialValue { - accountingStorage.SetAccount(sharedBalance.account) - } - if cc.IsPaid() { - return 0, nil + if nUb.Disabled { + Logger.Warning(fmt.Sprintf("Disabled user in shared group: %s (%s)", ubId, sharedGroupName)) + continue } } - return 0, nil - }) + sharedMoneyBalances := nUb.getBalancesForPrefix(cc.Destination, nUb.BalanceMap[CREDIT+cc.Direction], sharedGroupName) + allMoneySharedBalances = append(allMoneySharedBalances, sharedMoneyBalances...) + } + for _, sharedBalance := range sharedGroup.GetBalancesByStrategy(myBalance, allMoneySharedBalances) { + initialValue := sharedBalance.Value + sharedBalance.DebitMoney(cc, count, sharedBalance.account) + if sharedBalance.Value != initialValue { + accountingStorage.SetAccount(sharedBalance.account) + } + if cc.IsPaid() { + return + } + } + return } func (ub *Account) GetDefaultMoneyBalance(direction string) *Balance { @@ -533,3 +530,34 @@ func (ub *Account) GetSharedGroups() (groups []string) { } return } + +func (account *Account) GetUniqueSharedGroupMembers(destination, direction string) ([]string, error) { + creditBalances := account.getBalancesForPrefix(destination, account.BalanceMap[CREDIT+direction], "") + minuteBalances := account.getBalancesForPrefix(destination, account.BalanceMap[MINUTES+direction], "") + // gather all shared group ids + var sharedGroupIds []string + for _, cb := range creditBalances { + if cb.SharedGroup != "" { + sharedGroupIds = append(sharedGroupIds, cb.SharedGroup) + } + } + for _, mb := range minuteBalances { + if mb.SharedGroup != "" { + sharedGroupIds = append(sharedGroupIds, mb.SharedGroup) + } + } + var memberIds []string + for _, sgID := range sharedGroupIds { + sharedGroup, err := accountingStorage.GetSharedGroup(sgID, false) + if err != nil { + Logger.Warning(fmt.Sprintf("Could not get shared group: %v", sgID)) + return nil, err + } + for _, memberId := range sharedGroup.GetMembersExceptUser(account.Id) { + if !utils.IsSliceMember(memberIds, memberId) { + memberIds = append(memberIds, memberId) + } + } + } + return memberIds, nil +} diff --git a/engine/calldesc.go b/engine/calldesc.go index c1bb509f2..4e916ae7c 100644 --- a/engine/calldesc.go +++ b/engine/calldesc.go @@ -434,7 +434,7 @@ Returns the approximate max allowed session for user balance. It will try the ma If the user has no credit then it will return 0. If the user has postpayed plan it returns -1. */ -func (origCD *CallDescriptor) GetMaxSessionDuration() (time.Duration, error) { +func (origCD *CallDescriptor) getMaxSessionDuration(account *Account) (time.Duration, error) { if origCD.CallDuration < origCD.TimeEnd.Sub(origCD.TimeStart) { origCD.CallDuration = origCD.TimeEnd.Sub(origCD.TimeStart) } @@ -447,16 +447,11 @@ func (origCD *CallDescriptor) GetMaxSessionDuration() (time.Duration, error) { } var availableDuration time.Duration availableCredit := 0.0 - if userBalance, err := cd.getAccount(); err == nil && userBalance != nil { - if userBalance.AllowNegative { - return -1, nil - } else { - availableDuration, availableCredit, _ = userBalance.getCreditForPrefix(cd) - // Logger.Debug(fmt.Sprintf("available sec: %v credit: %v", availableSeconds, availableCredit)) - } + if account.AllowNegative { + return -1, nil } else { - Logger.Err(fmt.Sprintf("Could not get user balance for %s: %s.", cd.GetAccountKey(), err.Error())) - return 0, err + availableDuration, availableCredit, _ = account.getCreditForPrefix(cd) + // Logger.Debug(fmt.Sprintf("available sec: %v credit: %v", availableSeconds, availableCredit)) } //Logger.Debug(fmt.Sprintf("availableDuration: %v, availableCredit: %v", availableDuration, availableCredit)) initialDuration := cd.TimeEnd.Sub(cd.TimeStart) @@ -501,55 +496,82 @@ func (origCD *CallDescriptor) GetMaxSessionDuration() (time.Duration, error) { return utils.MinDuration(initialDuration, availableDuration), nil } +func (origCD *CallDescriptor) GetMaxSessionDuration() (time.Duration, error) { + cd := origCD.Clone() + if account, err := cd.getAccount(); err != nil || account == nil { + Logger.Err(fmt.Sprintf("Could not get user balance for %s: %s.", cd.GetAccountKey(), err.Error())) + return 0, err + } else { + return cd.getMaxSessionDuration(account) + } +} + // Interface method used to add/substract an amount of cents or bonus seconds (as returned by GetCost method) // from user's money balance. -func (cd *CallDescriptor) Debit() (cc *CallCost, err error) { +func (cd *CallDescriptor) debit(account *Account) (cc *CallCost, err error) { cc, err = cd.GetCost() if err != nil { Logger.Err(fmt.Sprintf(" Error getting cost for account key %v: %v", cd.GetAccountKey(), err)) return } - if userBalance, err := cd.getAccount(); err != nil { - Logger.Err(fmt.Sprintf(" Error retrieving user balance: %v", err)) - } else if userBalance == nil { - // Logger.Debug(fmt.Sprintf(" No user balance defined: %v", cd.GetAccountKey())) - } else { - //Logger.Debug(fmt.Sprintf(" Attempting to debit from %v, value: %v", cd.GetAccountKey(), cc.Cost+cc.ConnectFee)) - defer accountingStorage.SetAccount(userBalance) - //ub, _ := json.Marshal(userBalance) - //Logger.Debug(fmt.Sprintf("Account: %s", ub)) - //cCost, _ := json.Marshal(cc) - //Logger.Debug(fmt.Sprintf("CallCost: %s", cCost)) - if cc.Cost != 0 || (cc.deductConnectFee && cc.GetConnectFee() != 0) { - userBalance.debitCreditBalance(cc, true) - } - cost := 0.0 - // re-calculate call cost after balances - if cc.deductConnectFee { // add back the connectFee - cost += cc.GetConnectFee() - } - for _, ts := range cc.Timespans { - cost += ts.getCost() - cost = utils.Round(cost, roundingDecimals, utils.ROUNDING_MIDDLE) // just get rid of the extra decimals - } - cc.Cost = cost + //Logger.Debug(fmt.Sprintf(" Attempting to debit from %v, value: %v", cd.GetAccountKey(), cc.Cost+cc.ConnectFee)) + defer accountingStorage.SetAccount(account) + //ub, _ := json.Marshal(account) + //Logger.Debug(fmt.Sprintf("Account: %s", ub)) + //cCost, _ := json.Marshal(cc) + //Logger.Debug(fmt.Sprintf("CallCost: %s", cCost)) + if cc.Cost != 0 || (cc.deductConnectFee && cc.GetConnectFee() != 0) { + account.debitCreditBalance(cc, true) } + cost := 0.0 + // re-calculate call cost after balances + if cc.deductConnectFee { // add back the connectFee + cost += cc.GetConnectFee() + } + for _, ts := range cc.Timespans { + cost += ts.getCost() + cost = utils.Round(cost, roundingDecimals, utils.ROUNDING_MIDDLE) // just get rid of the extra decimals + } + cc.Cost = cost return } +func (cd *CallDescriptor) Debit() (cc *CallCost, err error) { + // lock all group members + if account, err := cd.getAccount(); err != nil || account == nil { + Logger.Err(fmt.Sprintf("Could not get user balance for %s: %s.", cd.GetAccountKey(), err.Error())) + return nil, err + } else { + if memberIds, err := account.GetUniqueSharedGroupMembers(cd.Destination, cd.Direction); err == nil { + AccLock.GuardMany(memberIds, func() (float64, error) { + cc, err = cd.debit(account) + return 0, err + }) + } else { + return nil, err + } + return cc, err + } +} + // Interface method used to add/substract an amount of cents or bonus seconds (as returned by GetCost method) // from user's money balance. // This methods combines the Debit and GetMaxSessionDuration and will debit the max available time as returned // by the GetMaxSessionTime method. The amount filed has to be filled in call descriptor. func (cd *CallDescriptor) MaxDebit() (cc *CallCost, err error) { - remainingDuration, err := cd.GetMaxSessionDuration() - if err != nil || remainingDuration == 0 { - return new(CallCost), fmt.Errorf("no more credit: %v", err) + if account, err := cd.getAccount(); err != nil || account == nil { + Logger.Err(fmt.Sprintf("Could not get user balance for %s: %s.", cd.GetAccountKey(), err.Error())) + return nil, err + } else { + remainingDuration, err := cd.getMaxSessionDuration(account) + if err != nil || remainingDuration == 0 { + return new(CallCost), fmt.Errorf("no more credit: %v", err) + } + if remainingDuration > 0 { // for postpaying client returns -1 + cd.TimeEnd = cd.TimeStart.Add(remainingDuration) + } + return cd.debit(account) } - if remainingDuration > 0 { // for postpaying client returns -1 - cd.TimeEnd = cd.TimeStart.Add(remainingDuration) - } - return cd.Debit() } func (cd *CallDescriptor) RefundIncrements() (left float64, err error) { diff --git a/engine/calldesc_test.go b/engine/calldesc_test.go index 4f7e32eca..f8d397b36 100644 --- a/engine/calldesc_test.go +++ b/engine/calldesc_test.go @@ -335,7 +335,7 @@ func TestMaxSessionTimeWithAccountAlias(t *testing.T) { result, err := cd.GetMaxSessionDuration() expected := time.Minute if result != expected || err != nil { - t.Errorf("Expected %v was %v", expected, result) + t.Errorf("Expected %v was %v, %v", expected, result, err) } }