diff --git a/trie.py b/trie.py index 93702fe2ab..18fe25b337 100644 --- a/trie.py +++ b/trie.py @@ -12,6 +12,25 @@ class DB(): def put(self,key,value): return self.db.Put(key,value) def delete(self,key): return self.db.Delete(key) +def hexarraykey_to_bin(key): + term = 1 if key[-1] == 16 else 0 + if term: key2 = key[:-1] + oddlen = len(key) % 2 + flags = 2 * term + oddlen + if oddlen: key = [flags] + key + else: key = [flags,0] + key + o = '' + for i in range(0,len(key),2): + o += chr(16 * key[i] + key[i+1]) + return o + +def bin_to_hexarraykey(bindata): + o = ['0123456789abcdef'.find(x) for x in key[1:].encode('hex')] + if o[0] >= 2: o.append(16) + if o[0] % 2 == 1: o = o[1:] + else: o = o[2:] + return o + databases = {} class Trie(): @@ -21,19 +40,6 @@ class Trie(): if dbfile not in databases: databases[dbfile] = DB(dbfile) self.db = databases[dbfile] - - def __encode_key(self,key): - term = 1 if key[-1] == 16 else 0 - oddlen = (len(key) - term) % 2 - prefix = ('0' if oddlen else '') - main = ''.join(['0123456789abcdef'[x] for x in key[:len(key)-term]]) - return chr(2 * term + oddlen) + (prefix+main).decode('hex') - - def __decode_key(self,key): - o = ['0123456789abcdef'.find(x) for x in key[1:].encode('hex')] - if key[0] == '\x01' or key[0] == '\x03': o = o[1:] - if key[0] == '\x02' or key[0] == '\x03': o.append(16) - return o def __get_state(self,node,key): if self.debug: print 'nk',node.encode('hex'),key @@ -45,7 +51,7 @@ class Trie(): raise Exception("node not found in database") elif len(curnode) == 2: (k2,v2) = curnode - k2 = self.__decode_key(k2) + k2 = hexarraykey_to_bin(k2) if len(key) >= len(k2) and k2 == key[:len(k2)]: return self.__get_state(v2,key[len(k2):]) else: @@ -69,7 +75,7 @@ class Trie(): return value else: if not node: - newnode = [ self.__encode_key(key), value ] + newnode = [ hexarraykey_to_bin(key), value ] return self.__put(newnode) curnode = rlp.decode(self.db.get(node)) if self.debug: print 'icn', curnode @@ -77,9 +83,9 @@ class Trie(): raise Exception("node not found in database") if len(curnode) == 2: (k2, v2) = curnode - k2 = self.__decode_key(k2) + k2 = hexarraykey_to_bin(k2) if key == k2: - newnode = [ self.__encode_key(key), value ] + newnode = [ hexarraykey_to_bin(key), value ] return self.__put(newnode) else: i = 0 @@ -96,7 +102,7 @@ class Trie(): if i == 0: return newhash3 else: - newnode4 = [ self.__encode_key(key[:i]), newhash3 ] + newnode4 = [ hexarraykey_to_bin(key[:i]), newhash3 ] return self.__put(newnode4) else: newnode = [ curnode[i] for i in range(17) ] @@ -114,15 +120,15 @@ class Trie(): if self.debug: print 'dcn', curnode if len(curnode) == 2: (k2, v2) = curnode - k2 = self.__decode_key(k2) + k2 = hexarraykey_to_bin(k2) if key == k2: return '' elif key[:len(k2)] == k2: newhash = self.__delete_state(v2,key[len(k2):]) childnode = rlp.decode(self.db.get(newhash)) if len(childnode) == 2: - newkey = k2 + self.__decode_key(childnode[0]) - newnode = [ self.__encode_key(newkey), childnode[1] ] + newkey = k2 + hexarraykey_to_bin(childnode[0]) + newnode = [ hexarraykey_to_bin(newkey), childnode[1] ] else: newnode = [ curnode[0], newhash ] return self.__put(newnode) @@ -142,8 +148,8 @@ class Trie(): if len(childnode) == 17: newnode2 = [ key[0], newnode[onlynode] ] elif len(childnode) == 2: - newkey = [onlynode] + self.__decode_key(childnode[0]) - newnode2 = [ self.__encode_key(newkey), childnode[1] ] + newkey = [onlynode] + hexarraykey_to_bin(childnode[0]) + newnode2 = [ hexarraykey_to_bin(newkey), childnode[1] ] else: newnode2 = newnode return self.__put(newnode2) @@ -154,7 +160,7 @@ class Trie(): if not curnode: raise Exception("node not found in database") if len(curnode) == 2: - key = self.__decode_key(curnode[0]) + key = hexarraykey_to_bin(curnode[0]) if key[-1] == 16: return 1 else: return self.__get_size(curnode[1]) elif len(curnode) == 17: @@ -170,15 +176,15 @@ class Trie(): if not curnode: raise Exception("node not found in database") if len(curnode) == 2: - lkey = self.__decode_key(curnode[0]) + lkey = hexarraykey_to_bin(curnode[0]) o = {} if lkey[-1] == 16: o[curnode[0]] = curnode[1] else: d = self.__to_dict(curnode[1]) for v in d: - subkey = self.__decode_key(v) - totalkey = self.__encode_key(lkey+subkey) + subkey = hexarraykey_to_bin(v) + totalkey = hexarraykey_to_bin(lkey+subkey) o[totalkey] = d[v] return o elif len(curnode) == 17: @@ -186,8 +192,8 @@ class Trie(): for i in range(16): d = self.__to_dict(curnode[i]) for v in d: - subkey = self.__decode_key(v) - totalkey = self.__encode_key([i] + subkey) + subkey = hexarraykey_to_bin(v) + totalkey = hexarraykey_to_bin([i] + subkey) o[totalkey] = d[v] if curnode[16]: o[chr(16)] = curnode[16] return o @@ -198,7 +204,7 @@ class Trie(): d = self.__to_dict(self.root) o = {} for v in d: - v2 = ''.join(['0123456789abcdef'[x] for x in self.__decode_key(v)[:-1]]) + v2 = ''.join(['0123456789abcdef'[x] for x in hexarraykey_to_bin(v)[:-1]]) if not as_hex: v2 = v2.decode('hex') o[v2] = d[v] return o