diff --git a/guardian/guardian.go b/guardian/guardian.go index 9aaf418db..52045da30 100644 --- a/guardian/guardian.go +++ b/guardian/guardian.go @@ -38,17 +38,39 @@ type itemLock struct { sync.Mutex } +// lock() executes combined lock with increasing counter +func (il *itemLock) lock() { + atomic.AddInt64(&il.cnt, 1) + il.Lock() +} + // unlock() executes combined lock with autoremoving lock from Guardian func (il *itemLock) unlock() { atomic.AddInt64(&il.cnt, -1) - if atomic.LoadInt64(&il.cnt) == 0 { // last lock in the queue + cnt := atomic.LoadInt64(&il.cnt) + if cnt < 0 { // already unlocked + return + } + if cnt == 0 { // last lock in the queue Guardian.Lock() - if il.cnt == 0 { // assurance that our counter was not modified in between read and lock - delete(Guardian.locksMap, il.keyID) - } + delete(Guardian.locksMap, il.keyID) Guardian.Unlock() } - il.Unlock() // will unlock a single count so the next one waiting for lock can proceed + il.Unlock() +} + +type itemLocks []*itemLock + +func (ils itemLocks) lock() { + for _, itmLock := range ils { + itmLock.lock() + } +} + +func (ils itemLocks) unlock() { + for _, itmLock := range ils { + itmLock.unlock() + } } // GuardianLock is an optimized locking system per locking key @@ -59,7 +81,7 @@ type GuardianLock struct { // lockItems locks a set of lockIDs // returning the lock structs so they can be later unlocked -func (guard *GuardianLock) lockItems(lockIDs []string) (itmLocks []*itemLock) { +func (guard *GuardianLock) lockItems(lockIDs []string) (itmLocks itemLocks) { guard.Lock() for _, lockID := range lockIDs { var itmLock *itemLock @@ -68,23 +90,15 @@ func (guard *GuardianLock) lockItems(lockIDs []string) (itmLocks []*itemLock) { itmLock = newItemLock(lockID) guard.locksMap[lockID] = itmLock } - atomic.AddInt64(&itmLock.cnt, 1) itmLocks = append(itmLocks, itmLock) } guard.Unlock() - for _, itmLock := range itmLocks { - itmLock.Lock() - } + + itmLocks.lock() return } -// unlockItems will unlock the items provided -func (guard *GuardianLock) unlockItems(itmLocks []*itemLock) { - for _, itmLock := range itmLocks { - itmLock.unlock() - } -} - +// Guard executes the handler between locks func (guard *GuardianLock) Guard(handler func() (interface{}, error), timeout time.Duration, lockIDs ...string) (reply interface{}, err error) { itmLocks := guard.lockItems(lockIDs) @@ -111,7 +125,8 @@ func (guard *GuardianLock) Guard(handler func() (interface{}, error), timeout ti case reply = <-rplyChan: } } - guard.unlockItems(itmLocks) + + itmLocks.unlock() return } @@ -129,7 +144,7 @@ func (guard *GuardianLock) GuardIDs(timeout time.Duration, lockIDs ...string) { // UnguardTimed attempts to unlock a set of locks based on their locksUUID func (guard *GuardianLock) UnguardIDs(lockIDs ...string) { - var itmLocks []*itemLock + var itmLocks itemLocks guard.RLock() for _, lockID := range lockIDs { var itmLock *itemLock @@ -139,6 +154,6 @@ func (guard *GuardianLock) UnguardIDs(lockIDs ...string) { } } guard.RUnlock() - guard.unlockItems(itmLocks) + itmLocks.unlock() return } diff --git a/guardian/guardian_test.go b/guardian/guardian_test.go index 7a21d0eea..5a52f9a84 100644 --- a/guardian/guardian_test.go +++ b/guardian/guardian_test.go @@ -87,7 +87,10 @@ func TestGuardianTimeout(t *testing.T) { } func TestGuardianGuardIDs(t *testing.T) { + + //lock with 3 keys lockIDs := []string{"test1", "test2", "test3"} + // make sure the keys are not in guardian before lock Guardian.RLock() for _, lockID := range lockIDs { if _, hasKey := Guardian.locksMap[lockID]; hasKey { @@ -95,6 +98,8 @@ func TestGuardianGuardIDs(t *testing.T) { } } Guardian.RUnlock() + + // lock 3 items tStart := time.Now() lockDur := 2 * time.Millisecond Guardian.GuardIDs(lockDur, lockIDs...) @@ -107,8 +112,13 @@ func TestGuardianGuardIDs(t *testing.T) { } } Guardian.RUnlock() - go Guardian.GuardIDs(time.Duration(1*time.Millisecond), lockIDs[1:]...) // to test counter - time.Sleep(20 * time.Microsecond) // give time for goroutine to lock + secLockDur := time.Duration(1 * time.Millisecond) + + // second lock to test counter + go Guardian.GuardIDs(secLockDur, lockIDs[1:]...) + time.Sleep(20 * time.Microsecond) // give time for goroutine to lock + + // check if counters were properly increased Guardian.RLock() lkID := lockIDs[0] eCnt := int64(1) @@ -125,18 +135,31 @@ func TestGuardianGuardIDs(t *testing.T) { t.Errorf("Unexpected counter: %d for itmLock with id %s", cnt, lkID) } lkID = lockIDs[2] - eCnt = int64(2) + eCnt = int64(1) // we did not manage to increase it yet since it did not pass first lock if itmLock, hasKey := Guardian.locksMap[lkID]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", lkID) } else if cnt := atomic.LoadInt64(&itmLock.cnt); cnt != eCnt { t.Errorf("Unexpected counter: %d for itmLock with id %s", cnt, lkID) } Guardian.RUnlock() + + time.Sleep(lockDur + secLockDur + time.Millisecond) // give time to unlock before proceeding + + // make sure all counters were removed + for _, lockID := range lockIDs { + if _, hasKey := Guardian.locksMap[lockID]; hasKey { + t.Errorf("Unexpected lockID found: %s", lockID) + } + } + + // test lock without timer Guardian.GuardIDs(0, lockIDs...) if totalLockDur := time.Now().Sub(tStart); totalLockDur < lockDur { t.Errorf("Lock duration too small") } time.Sleep(time.Duration(30) * time.Millisecond) + + // making sure the items stay locked Guardian.RLock() if len(Guardian.locksMap) != 3 { t.Errorf("locksMap should be have 3 elements, have: %+v", Guardian.locksMap) @@ -149,8 +172,11 @@ func TestGuardianGuardIDs(t *testing.T) { } } Guardian.RUnlock() + Guardian.UnguardIDs(lockIDs...) time.Sleep(time.Duration(50) * time.Millisecond) + + // make sure items were unlocked Guardian.RLock() if len(Guardian.locksMap) != 0 { t.Errorf("locksMap should have 0 elements, has: %+v", Guardian.locksMap)