repository: Fix potential race condition

If we crash between compact_segments() and write_index() and the
transaction deletes objects that are newer than the current index
might become undeleted.
This commit is contained in:
Jonas Borgström 2014-02-21 20:20:17 +01:00
parent bd22bc8cb2
commit 6425d16aa8
2 changed files with 53 additions and 63 deletions

View File

@ -6,7 +6,6 @@ import os
import shutil import shutil
import struct import struct
import sys import sys
import time
from zlib import crc32 from zlib import crc32
from .hashindex import NSIndex from .hashindex import NSIndex
@ -87,15 +86,17 @@ class Repository(object):
def get_transaction_id(self): def get_transaction_id(self):
index_transaction_id = self.get_index_transaction_id() index_transaction_id = self.get_index_transaction_id()
segments_transaction_id = self.io.get_segments_transaction_id() segments_transaction_id = self.io.get_segments_transaction_id()
if index_transaction_id is not None and segments_transaction_id is None:
raise self.CheckNeeded(self.path)
# Attempt to automatically rebuild index if we crashed between commit # Attempt to automatically rebuild index if we crashed between commit
# tag write and index save # tag write and index save
if (index_transaction_id if index_transaction_id is not None else -1) < (segments_transaction_id if segments_transaction_id is not None else -1):
self.replay_segments(index_transaction_id, segments_transaction_id)
index_transaction_id = self.get_index_transaction_id()
if index_transaction_id != segments_transaction_id: if index_transaction_id != segments_transaction_id:
raise self.CheckNeeded(self.path) if index_transaction_id is not None and index_transaction_id > segments_transaction_id:
return index_transaction_id replay_from = None
else:
replay_from = index_transaction_id
self.replay_segments(replay_from, segments_transaction_id)
return self.get_index_transaction_id()
def open(self, path): def open(self, path):
self.path = path self.path = path
@ -175,21 +176,24 @@ class Repository(object):
""" """
if not self.compact: if not self.compact:
return return
index_transaction_id = self.get_index_transaction_id()
def lookup(tag, key):
return tag == TAG_PUT and self.index.get(key, (-1, -1))[0] == segment
segments = self.segments segments = self.segments
for segment in sorted(self.compact): for segment in sorted(self.compact):
if segments[segment] > 0: if self.io.segment_exists(segment):
for tag, key, data in self.io.iter_objects(segment, lookup, include_data=True): for tag, key, data in self.io.iter_objects(segment, include_data=True):
new_segment, offset = self.io.write_put(key, data) if tag == TAG_PUT and self.index.get(key, (-1, -1))[0] == segment:
self.index[key] = new_segment, offset new_segment, offset = self.io.write_put(key, data)
segments.setdefault(new_segment, 0) self.index[key] = new_segment, offset
segments[new_segment] += 1 segments.setdefault(new_segment, 0)
segments[segment] -= 1 segments[new_segment] += 1
segments[segment] -= 1
elif tag == TAG_DELETE:
if index_transaction_id is None or segment > index_transaction_id:
self.io.write_delete(key)
assert segments[segment] == 0 assert segments[segment] == 0
self.io.write_commit() self.io.write_commit()
for segment in self.compact: for segment in sorted(self.compact):
assert self.segments.pop(segment) == 0 assert self.segments.pop(segment) == 0
self.io.delete_segment(segment) self.io.delete_segment(segment)
self.compact = set() self.compact = set()
@ -215,10 +219,10 @@ class Repository(object):
elif tag == TAG_DELETE: elif tag == TAG_DELETE:
try: try:
s, _ = self.index.pop(key) s, _ = self.index.pop(key)
self.segments[s] -= 1
self.compact.add(s)
except KeyError: except KeyError:
raise self.CheckNeeded(self.path) pass
self.segments[s] -= 1
self.compact.add(s)
self.compact.add(segment) self.compact.add(segment)
elif tag == TAG_COMMIT: elif tag == TAG_COMMIT:
continue continue
@ -246,21 +250,16 @@ class Repository(object):
assert not self._active_txn assert not self._active_txn
report_progress('Starting repository check...') report_progress('Starting repository check...')
index_transaction_id = self.get_index_transaction_id() try:
segments_transaction_id = self.io.get_segments_transaction_id() transaction_id = self.get_transaction_id()
if index_transaction_id is None and segments_transaction_id is None:
return True
if segments_transaction_id is not None:
transaction_id = segments_transaction_id
else:
transaction_id = index_transaction_id
self.get_index(None)
if index_transaction_id == segments_transaction_id:
current_index = self.get_read_only_index(transaction_id) current_index = self.get_read_only_index(transaction_id)
else: except Exception:
transaction_id = self.io.get_segments_transaction_id()
current_index = None current_index = None
report_progress('No suitable index found', error=True) if transaction_id is None:
transaction_id = self.get_index_transaction_id()
segments_transaction_id = self.io.get_segments_transaction_id()
self.get_index(None)
for segment, filename in self.io.segment_iterator(): for segment, filename in self.io.segment_iterator():
if segment > transaction_id: if segment > transaction_id:
continue continue
@ -302,7 +301,7 @@ class Repository(object):
self.io.write_commit() self.io.write_commit()
self.io.close_segment() self.io.close_segment()
if current_index and not repair: if current_index and not repair:
if len(current_index) != len(self.index) and False: if len(current_index) != len(self.index):
report_progress('Index object count mismatch. {} != {}'.format(len(current_index), len(self.index)), error=True) report_progress('Index object count mismatch. {} != {}'.format(len(current_index), len(self.index)), error=True)
elif current_index: elif current_index:
for key, value in self.index.iteritems(): for key, value in self.index.iteritems():
@ -369,13 +368,13 @@ class Repository(object):
self.get_index(self.get_transaction_id()) self.get_index(self.get_transaction_id())
try: try:
segment, offset = self.index.pop(id) segment, offset = self.index.pop(id)
self.segments[segment] -= 1
self.compact.add(segment)
segment = self.io.write_delete(id)
self.compact.add(segment)
self.segments.setdefault(segment, 0)
except KeyError: except KeyError:
raise self.DoesNotExist(self.path) raise self.DoesNotExist(self.path)
self.segments[segment] -= 1
self.compact.add(segment)
segment = self.io.write_delete(id)
self.compact.add(segment)
self.segments.setdefault(segment, 0)
def preload(self, ids): def preload(self, ids):
"""Preload objects (only applies to remote repositories """Preload objects (only applies to remote repositories
@ -479,7 +478,10 @@ class LoggedIO(object):
except OSError: except OSError:
pass pass
def iter_objects(self, segment, lookup=None, include_data=False): def segment_exists(self, segment):
return os.path.exists(self.segment_filename(segment))
def iter_objects(self, segment, include_data=False):
fd = self.get_fd(segment) fd = self.get_fd(segment)
fd.seek(0) fd.seek(0)
if fd.read(8) != MAGIC: if fd.read(8) != MAGIC:
@ -498,11 +500,10 @@ class LoggedIO(object):
key = None key = None
if tag in (TAG_PUT, TAG_DELETE): if tag in (TAG_PUT, TAG_DELETE):
key = rest[:32] key = rest[:32]
if not lookup or lookup(tag, key): if include_data:
if include_data: yield tag, key, rest[32:]
yield tag, key, rest[32:] else:
else: yield tag, key, offset
yield tag, key, offset
offset += size offset += size
header = fd.read(self.header_fmt.size) header = fd.read(self.header_fmt.size)

View File

@ -115,9 +115,11 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
def add_keys(self): def add_keys(self):
self.repository.put(b'00000000000000000000000000000000', b'foo') self.repository.put(b'00000000000000000000000000000000', b'foo')
self.repository.put(b'00000000000000000000000000000001', b'bar') self.repository.put(b'00000000000000000000000000000001', b'bar')
self.repository.put(b'00000000000000000000000000000003', b'bar')
self.repository.commit() self.repository.commit()
self.repository.put(b'00000000000000000000000000000001', b'bar2') self.repository.put(b'00000000000000000000000000000001', b'bar2')
self.repository.put(b'00000000000000000000000000000002', b'boo') self.repository.put(b'00000000000000000000000000000002', b'boo')
self.repository.delete(b'00000000000000000000000000000003')
def test_replay_of_missing_index(self): def test_replay_of_missing_index(self):
self.add_keys() self.add_keys()
@ -125,7 +127,7 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
if name.startswith('index.'): if name.startswith('index.'):
os.unlink(os.path.join(self.repository.path, name)) os.unlink(os.path.join(self.repository.path, name))
self.reopen() self.reopen()
self.assert_equal(len(self.repository), 2) self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True) self.assert_equal(self.repository.check(), True)
def test_crash_before_compact_segments(self): def test_crash_before_compact_segments(self):
@ -174,7 +176,6 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
self.assert_equal(len(self.repository), 3) self.assert_equal(len(self.repository), 3)
class RepositoryCheckTestCase(RepositoryTestCaseBase): class RepositoryCheckTestCase(RepositoryTestCaseBase):
def list_indices(self): def list_indices(self):
@ -249,10 +250,6 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
def test_repair_missing_commit_segment(self): def test_repair_missing_commit_segment(self):
self.add_objects([[1, 2, 3], [4, 5, 6]]) self.add_objects([[1, 2, 3], [4, 5, 6]])
self.delete_segment(1) self.delete_segment(1)
self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
self.check(status=False)
self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
self.check(repair=True, status=True)
self.assert_raises(Repository.DoesNotExist, lambda: self.get_objects(4)) self.assert_raises(Repository.DoesNotExist, lambda: self.get_objects(4))
self.assert_equal(set([1, 2, 3]), self.list_objects()) self.assert_equal(set([1, 2, 3]), self.list_objects())
@ -261,11 +258,9 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
with open(os.path.join(self.tmppath, 'repository', 'data', '0', '1'), 'r+b') as fd: with open(os.path.join(self.tmppath, 'repository', 'data', '0', '1'), 'r+b') as fd:
fd.seek(-1, os.SEEK_END) fd.seek(-1, os.SEEK_END)
fd.write(b'X') fd.write(b'X')
self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
self.check(status=False)
self.check(repair=True, status=True)
self.get_objects(3)
self.assert_raises(Repository.DoesNotExist, lambda: self.get_objects(4)) self.assert_raises(Repository.DoesNotExist, lambda: self.get_objects(4))
self.check(status=True)
self.get_objects(3)
self.assert_equal(set([1, 2, 3]), self.list_objects()) self.assert_equal(set([1, 2, 3]), self.list_objects())
def test_repair_no_commits(self): def test_repair_no_commits(self):
@ -286,8 +281,6 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
def test_repair_missing_index(self): def test_repair_missing_index(self):
self.add_objects([[1, 2, 3], [4, 5, 6]]) self.add_objects([[1, 2, 3], [4, 5, 6]])
self.delete_index() self.delete_index()
self.check(status=False)
self.check(repair=True, status=True)
self.check(status=True) self.check(status=True)
self.get_objects(4) self.get_objects(4)
self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects()) self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())
@ -296,12 +289,8 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
self.add_objects([[1, 2, 3], [4, 5, 6]]) self.add_objects([[1, 2, 3], [4, 5, 6]])
self.assert_equal(self.list_indices(), ['index.1']) self.assert_equal(self.list_indices(), ['index.1'])
self.rename_index('index.100') self.rename_index('index.100')
self.assert_equal(self.list_indices(), ['index.100'])
self.assert_raises(Repository.CheckNeeded, lambda: self.get_objects(4))
self.check(status=False)
self.check(repair=True, status=True)
self.assert_equal(self.list_indices(), ['index.1'])
self.check(status=True) self.check(status=True)
self.assert_equal(self.list_indices(), ['index.1'])
self.get_objects(4) self.get_objects(4)
self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects()) self.assert_equal(set([1, 2, 3, 4, 5, 6]), self.list_objects())