diff --git a/guardian/guardian.go b/guardian/guardian.go index 99d4dcde7..2858c173e 100644 --- a/guardian/guardian.go +++ b/guardian/guardian.go @@ -19,6 +19,7 @@ package guardian import ( "sync" + "sync/atomic" "time" ) @@ -32,14 +33,14 @@ func newItemLock(keyID string) *itemLock { // itemLock represents one lock with key autodestroy type itemLock struct { keyID string // store it so we know what to destroy - cnt int + cnt int64 sync.Mutex } // unlock() executes combined lock with autoremoving lock from Guardian func (il *itemLock) unlock() { - il.cnt-- - if il.cnt == 0 { // last lock in the queue + atomic.AddInt64(&il.cnt, -1) + if il.count() == 0 { // last lock in the queue Guardian.Lock() delete(Guardian.locksMap, il.keyID) Guardian.Unlock() @@ -47,6 +48,10 @@ func (il *itemLock) unlock() { il.Unlock() } +func (il *itemLock) count() int64 { + return atomic.LoadInt64(&il.cnt) +} + // GuardianLock is an optimized locking system per locking key type GuardianLock struct { locksMap map[string]*itemLock @@ -64,7 +69,7 @@ func (guard *GuardianLock) lockItems(lockIDs []string) (itmLocks []*itemLock) { itmLock = newItemLock(lockID) Guardian.locksMap[lockID] = itmLock } - itmLock.cnt++ + atomic.AddInt64(&itmLock.cnt, 1) itmLocks = append(itmLocks, itmLock) } guard.Unlock() diff --git a/guardian/guardian_test.go b/guardian/guardian_test.go index 98b7ba8b8..76495bd85 100644 --- a/guardian/guardian_test.go +++ b/guardian/guardian_test.go @@ -94,7 +94,7 @@ func TestGuardianGuardIDs(t *testing.T) { for _, lockID := range lockIDs { if itmLock, hasKey := Guardian.locksMap[lockID]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", lockID) - } else if itmLock.cnt != 1 { + } else if itmLock.count() != 1 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } } @@ -102,17 +102,17 @@ func TestGuardianGuardIDs(t *testing.T) { time.Sleep(20 * time.Microsecond) // give time for goroutine to lock if itmLock, hasKey := Guardian.locksMap["test1"]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", "test1") - } else if itmLock.cnt != 1 { + } else if itmLock.count() != 1 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } if itmLock, hasKey := Guardian.locksMap["test2"]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", "test2") - } else if itmLock.cnt != 2 { + } else if itmLock.count() != 2 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } if itmLock, hasKey := Guardian.locksMap["test3"]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", "test3") - } else if itmLock.cnt != 2 { + } else if itmLock.count() != 2 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } Guardian.GuardIDs(0, lockIDs...) @@ -124,15 +124,15 @@ func TestGuardianGuardIDs(t *testing.T) { t.Errorf("locksMap should be have 3 elements, have: %+v", Guardian.locksMap) } else if itmLock, hasKey := Guardian.locksMap["test1"]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", "test1") - } else if itmLock.cnt != 1 { + } else if itmLock.count() != 1 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } else if itmLock, hasKey := Guardian.locksMap["test2"]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", "test2") - } else if itmLock.cnt != 1 { + } else if itmLock.count() != 1 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } else if itmLock, hasKey := Guardian.locksMap["test3"]; !hasKey { t.Errorf("Cannot find lock for lockID: %s", "test2") - } else if itmLock.cnt != 1 { + } else if itmLock.count() != 1 { t.Errorf("Unexpected itmLock found: %+v", itmLock) } Guardian.UnguardIDs(lockIDs...)