guillaume-be / bpe-example

Supporting code for post on Rust implementation of Byte Pair Encoding
License: Apache License 2.0
Recently, I've been working on my NLP homework on the BPE algorithm. I came up with an efficient algorithm a few days ago and saw your article yesterday. I found that my method is very similar to the linked-list version mentioned in the article, but it is more memory-efficient.

I implemented it in Python, so I'm currently unable to directly compare the running speed of the two. In principle, however, this implementation has lower memory overhead, and I believe it should have at least comparable running speed. I have, though, compared my version against the Hugging Face BPE trainer.
The experimental setting is to process a Chinese corpus of 2,547,540 words and generate a target subword vocabulary of size 1e4, with a minimum frequency of 10.

The Hugging Face BPE trainer (run with `os.environ["TOKENIZERS_PARALLELISM"] = "false"`) cost:

5.31s user 1.57s system 132% cpu 5.184 total

And my implementation cost:

2.73s user 0.04s system 99% cpu 2.778 total

I also tried a Chinese Wikipedia corpus of 1 GB in size. The Hugging Face version could not complete the task within 12 hours, while my version completed it in less than 10 minutes.
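For reference, the Hugging Face baseline can be set up along these lines. This is my own sketch rather than the exact benchmark script (which the post does not include), and `corpus.txt` is a placeholder path:

```python
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # as in the timing above

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

# Train a BPE vocabulary of 1e4 subwords with a minimum frequency of 10.
tokenizer = Tokenizer(BPE())
trainer = BpeTrainer(vocab_size=10000, min_frequency=10)
tokenizer.train(["corpus.txt"], trainer)
```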
Here, I'll list some obvious differences, and attach my code.

1. The positions of every pair are stored in a dictionary of type `Dict[Tuple[str, str], List[int]]`. Therefore, I do not push each position's pair into the priority queue separately. Instead, I push each category of pairs into the priority queue once; after retrieving a pair from the queue, I merge all of its occurrences directly, according to the position list stored in the dictionary.

2. The current segmentation of the corpus is tracked in an array `seg_status`, which records the length of the subword starting (or ending) at each corpus position. This lets the previous and next subwords around any position be located in constant time, for example:

```python
i = 3
pre_sub_word_len = seg_status[i - 1]
pre_sub_word_start = i - pre_sub_word_len
cur_sub_word_len = seg_status[i]
nxt_sub_word_start = i + cur_sub_word_len
nxt_sub_word_len = seg_status[nxt_sub_word_start]
```

3. Pair frequencies in the priority queue are updated lazily in the "find best pair" step. As long as the popped pair's frequency in the queue does not match its actual frequency at that moment, it is pushed back into the queue with the current frequency, until the two become equal (see the sketch after this list).
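To make points 1 and 3 concrete, here is a minimal, self-contained sketch of the lazy-update pattern. The toy `pair_pos` / `pair_freq` data is made up for illustration; the corresponding structures in the real code below are `word_pair_pos` and `word_pair_len`:

```python
import heapq

# Toy stand-ins: positions of each pair category, and its current frequency.
pair_pos = {("a", "b"): [0, 4, 9], ("b", "c"): [2, 6]}
pair_freq = {pair: len(pos) for pair, pos in pair_pos.items()}

# One heap entry per pair *category*, not per occurrence.
heap = [(-freq, pair) for pair, freq in pair_freq.items()]
heapq.heapify(heap)

# Simulate staleness: a merge elsewhere invalidated one ("a", "b")
# occurrence, but the old heap entry still claims a frequency of 3.
pair_freq[("a", "b")] = 2

def pop_best(heap, pair_freq):
    """Pop the most frequent pair, lazily re-pushing stale entries until
    the cached frequency matches the actual one."""
    while heap:
        neg_freq, pair = heapq.heappop(heap)
        if -neg_freq == pair_freq[pair]:
            return pair, -neg_freq
        heapq.heappush(heap, (-pair_freq[pair], pair))
    return None, 0

print(pop_best(heap, pair_freq))  # -> (('a', 'b'), 2)
```

The real `most_frequent_combination` below additionally discards pairs whose actual frequency has fallen to `min_freq` or below.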
Here is my Python implementation (use the `BPETrainer.train_from_file` function to learn a subword vocabulary with subword frequencies):
```python
from array import array
from collections import defaultdict
from itertools import chain
from copy import deepcopy
import heapq
from tqdm import tqdm
from typing import Union, List, Dict

try:
    from itertools import pairwise  # available on Python >= 3.10
    print("Using itertools.pairwise")
except ImportError:
    print("Using custom pairwise")
    from itertools import tee

    def pairwise(iterable):
        """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)


def preproc_idx(full_indices, word_a, word_b):
    # Get the usable indices.
    if word_a == word_b:  # Same word: merge from back to front.
        # For example, for "1 0 0 0" merging (0, 0) from back to front gives "1 0 00".
        len_a = len(word_a)
        indices = full_indices[::-1]
        new_indices = [indices[0]]
        for idx in indices[1:]:
            if new_indices[-1] - idx != len_a:
                new_indices.append(idx)
        return array('I', new_indices[::-1])
    else:
        return full_indices


class BPETrainer:
    def __init__(self, vocab_size, min_freq: int = 10, compress_threshold: float = 0, single_char: bool = True) -> None:
        self.vocab_size = vocab_size
        self.min_freq = min_freq
        self.compress_threshold = compress_threshold
        self.corpus = self.word_pair_pos = self.word_pair_len = None
        self.seg_status = self.vocab = self.pair_freq_queue = self.word_count = None
        self.single_char = single_char

    def train_from_file(self, path: Union[List[str], str], verbose: bool = False) -> Dict[str, int]:
        self.init_word_pair(self.load_file(path, verbose), self.min_freq, verbose)
        self.epoch = 0
        while len(self.vocab) < self.vocab_size:
            comb, freq = self.most_frequent_combination()
            if freq <= self.min_freq:
                break
            self.merge_word(comb, freq)
            if verbose and (freq > int(1e5) or self.epoch % 50 == 0):
                self.log(self.epoch, comb, freq)
            word_comb = "".join(comb)
            self.vocab[word_comb] = self.word_count[word_comb] = freq
            self.epoch += 1
        if verbose:
            print(f"Final Vocab ({'' if self.single_char else 'Not '}Including Single Characters) Size: {len(self.vocab)} in {self.epoch} Epochs")
        # The original post wraps the result in a `BPE` tokenizer class, which is
        # not included here; return the plain vocabulary (subword -> frequency).
        return dict(self.vocab)

    @classmethod
    def load_file(cls, path: Union[List[str], str], verbose: bool):
        path = path if isinstance(path, list) else [path]
        pbar = tqdm(path) if verbose else path
        for p in pbar:
            with open(p, "r") as f:
                # "#" acts as a separator that no pair may cross.
                yield f.read().replace("\n", "#")

    def init_word_pair(self, corpus_iter, min_freq: int, verbose: bool):
        corpus = "#".join(chain([''], corpus_iter, ['']))
        # Positions are stored in unsigned 32-bit arrays ('I').
        assert len(corpus) < 4294967295
        word_pair_pos = defaultdict(lambda: array('I'))
        vocab = defaultdict(int)
        if verbose:
            print("Initializing pair counts...")
        pbar = enumerate(pairwise(corpus), start=0)
        pbar = pbar if not verbose else tqdm(pbar, total=len(corpus) - 1)
        if self.single_char:
            for i, (pre_char, nxt_char) in pbar:
                if pre_char == "#":
                    continue
                vocab[pre_char] += 1
                if nxt_char == "#":
                    continue
                word_pair_pos[(pre_char, nxt_char)].append(i)
        else:
            for i, (pre_char, nxt_char) in pbar:
                if pre_char == "#" or nxt_char == "#":
                    continue
                word_pair_pos[(pre_char, nxt_char)].append(i)
        self.word_pair_pos = {
            word_pair: indices
            for word_pair, indices in word_pair_pos.items()
            if len(indices) >= min_freq
        }
        self.word_pair_len = {
            word_pair: len(indices)
            for word_pair, indices in self.word_pair_pos.items()
        }
        self.pair_freq_queue = [(-freq, word_pair) for word_pair, freq in self.word_pair_len.items()]
        heapq.heapify(self.pair_freq_queue)
        self.vocab = vocab
        self.corpus = corpus
        self.word_count = deepcopy(self.vocab)
        # seg_status[i] holds the length of the subword starting (or ending) at i.
        # 'B' caps subword length at 255 characters.
        self.seg_status = array('B', [1] * (len(self.corpus) + 2))
        if verbose:
            print("Init finished!")

    def log(self, epoch, comb, freq):
        print(f"epoch: {epoch}\tcomb: {' + '.join(comb)}\tfreq: {freq}")

    def most_frequent_combination(self):
        # Lazy updates: re-push stale entries until the cached frequency
        # matches the actual frequency.
        while len(self.pair_freq_queue) > 0:
            cached_freq, comb = heapq.heappop(self.pair_freq_queue)
            ground_freq = self.word_pair_len.get(comb, 0)
            cached_freq = -cached_freq
            if cached_freq == ground_freq:
                return comb, cached_freq
            elif ground_freq > self.min_freq:
                heapq.heappush(self.pair_freq_queue, (-ground_freq, comb))
                if ground_freq / len(self.word_pair_pos[comb]) < self.compress_threshold:
                    # Too many stale positions: filter out invalidated ones.
                    self.word_pair_pos[comb] = self.compress_indices(*comb)
            else:
                self.word_pair_len.pop(comb, None)
                self.word_pair_pos.pop(comb, None)
        return None, 0

    def compress_indices(self, word_a, word_b, indices=None):
        seg = self.seg_status
        len_a, len_b = len(word_a), len(word_b)
        indices = indices or self.word_pair_pos[(word_a, word_b)]
        # Keep only positions where word_a and word_b are still adjacent subwords.
        return array('I', (
            i for i in indices
            if seg[i] == len_a and seg[i + len_a] == len_b
        ))

    def merge_word(self, comb, freq):
        word_a, word_b = comb
        word_comb = word_a + word_b
        len_a, len_b, len_comb = len(word_a), len(word_b), len(word_comb)
        seg_status = self.seg_status
        word_pair_pos = self.word_pair_pos
        corpus = self.corpus
        indices = word_pair_pos[comb]
        if len(indices) > freq:
            # Some recorded positions are stale; drop them first.
            indices = self.compress_indices(word_a, word_b, indices)
        indices = preproc_idx(indices, word_a, word_b)
        self.word_count[word_a] -= len(indices)
        self.word_count[word_b] -= len(indices)
        new_pairs = defaultdict(list)
        for i in indices:
            # Use seg_status to find the previous and next subwords.
            pre_end, nxt_start = i, i + len_comb
            nxt_end = seg_status[nxt_start] + nxt_start
            pre_start = i - seg_status[i - 1]
            pre_word, nxt_word = corpus[pre_start:pre_end], corpus[nxt_start:nxt_end]
            if pre_word != "#":
                try:
                    self.word_pair_len[pre_word, word_a] -= 1
                except KeyError:
                    pass
                if pre_word == word_b and self.get_pre_word(pre_start) == word_a:
                    # The previous two subwords are word_a and word_b themselves,
                    # e.g. "1 0 1 0" becomes "10 10".
                    new_pairs[word_comb, word_comb].append(pre_start - len_a)
                else:
                    new_pairs[pre_word, word_comb].append(pre_start)
            if nxt_word != "#":
                if not (nxt_word == word_a and self.get_nxt_word(nxt_end) == word_b):
                    # The next two subwords are not word_a followed by word_b.
                    try:
                        self.word_pair_len[word_b, nxt_word] -= 1
                    except KeyError:
                        pass
                    new_pairs[word_comb, nxt_word].append(i)
        self.update(new_pairs)
        # Update the segmentation status of every merged occurrence.
        if len_b == 1:
            for i in indices:
                seg_status[i] = len_comb
                seg_status[i + len_a] = len_comb
        else:
            for i in indices:
                seg_status[i] = len_comb
                seg_status[i + len_a] = 0
                seg_status[i + len_comb - 1] = len_comb
        word_pair_pos.pop(comb)
        self.word_pair_len.pop(comb)
        if self.word_count[word_a] <= 0:
            if self.vocab.pop(word_a, None):
                print(f"\tremove {word_a}")
        if word_b != word_a and self.word_count[word_b] <= 0:
            if self.vocab.pop(word_b, None):
                print(f"\tremove {word_b}")

    def update(self, new_pairs):
        for k, v in new_pairs.items():
            if len(v) >= self.min_freq:
                data = self.word_pair_pos[k] = array("I", v)
                freq = self.word_pair_len[k] = len(data)
                heapq.heappush(self.pair_freq_queue, (-freq, k))

    def get_pre_word(self, init_pos):
        pos = init_pos - self.seg_status[init_pos - 1]
        return self.corpus[pos:init_pos]

    def get_nxt_word(self, init_pos):
        pos = init_pos + self.seg_status[init_pos]
        return self.corpus[init_pos:pos]
```