
bpe-example's Issues

A More Memory-Efficient (and Maybe Faster) Implementation for the Linked List Version

I have recently been working on my NLP homework on the BPE algorithm. I came up with an efficient algorithm a few days ago and saw your article yesterday. My method is very similar to the Linked List Version described in the article, but it is more memory-efficient.

I implemented it in Python, so I cannot directly compare the running speed of the two versions. In principle, however, this implementation has lower memory overhead, and I believe its speed should be at least comparable. I did compare my version against the Hugging Face BPE trainer.

The experimental setting: a Chinese corpus of 2,547,540 words, a target subword vocabulary of 10,000, and a minimum frequency of 10.

The Hugging Face BPE trainer (run with os.environ["TOKENIZERS_PARALLELISM"] = "false") took:

5.31s user 1.57s system 132% cpu 5.184 total

My implementation took:

2.73s user 0.04s system 99% cpu 2.778 total

I also tested on a 1 GB Chinese Wikipedia corpus. The Hugging Face version could not complete the task within 12 hours, while my version finished in under 10 minutes.

Below I list the main differences and attach my code.

  1. I store all start positions of each kind of pair in a dict of type Dict[Tuple[str, str], List[int]]. Therefore, I do not push a separate entry into the priority queue for every occurrence of a pair; instead, I push one entry per pair. After popping a pair from the queue, I merge all of its occurrences directly, using the position list stored in the dictionary.
  2. Not all start positions stored in the dictionary remain valid. This is similar to your approach, but the validation method is somewhat different: I use a uint8 array of the same length as the sequence, which I call seg_status.
    For example, for a sequence split into [app, l, e], seg_status is [3, 0, 3, 1, 1]: the first and last positions of each subword store its length. When I obtain the start position 3 of 'l', I can check it against the length stored in seg_status, and also quickly locate the previous subword 'app' and the next subword 'e' the same way. Note that seg_status is initialized to all 1s, so every query is valid at the start. Here is the example code:
# "apple" segmented as [app, l, e]  ->  seg_status = [3, 0, 3, 1, 1]
i = 3                                               # start position of 'l'
pre_sub_word_len = seg_status[i - 1]                # 3: the last cell of 'app' stores its length
pre_sub_word_start = i - pre_sub_word_len           # 0

cur_sub_word_len = seg_status[i]                    # 1

nxt_sub_word_start = i + cur_sub_word_len           # 4: start of 'e'
nxt_sub_word_len = seg_status[nxt_sub_word_start]   # 1
  3. Because not every position stored for a pair is still valid, the element popped from the priority queue may not actually have the highest frequency. Therefore, I added another dictionary that stores the actual count of every pair; it is updated during the merge process.
    It should be noted that position lists only become shorter during merging, so we can exploit this in the find-best-pair step: whenever the frequency cached in the queue for a popped pair does not match its actual frequency, the pair is pushed back into the queue with its current frequency, until the two agree (a minimal sketch of this check follows the list).
    Moreover, when popping a pair we can also check the ratio of its actual frequency to the length of its stored position list (between 0 and 1). If this ratio falls below a configurable threshold, we validate and compress the position list early to save memory. (This keeps the memory used by the dictionary of position lists roughly constant instead of letting it grow continuously.)
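
To make points 1 and 3 concrete, here is a minimal, self-contained sketch (not my full implementation, which follows below) of the per-pair bookkeeping and the stale-frequency check on pop. The names pair_positions, pair_freq, heap, push, and pop_best are illustrative only:

import heapq
from collections import defaultdict
from typing import Dict, List, Tuple

pair_positions: Dict[Tuple[str, str], List[int]] = defaultdict(list)  # pair -> start offsets (point 1)
pair_freq: Dict[Tuple[str, str], int] = {}                            # pair -> actual count (point 3)
heap: List[Tuple[int, Tuple[str, str]]] = []                          # entries of the form (-cached_freq, pair)

def push(pair: Tuple[str, str]) -> None:
    heapq.heappush(heap, (-pair_freq[pair], pair))

def pop_best(min_freq: int):
    # Lazy deletion: a popped entry is trusted only if its cached count still matches
    # the actual count; otherwise it is pushed back with the current count.
    while heap:
        neg_cached, pair = heapq.heappop(heap)
        true_freq = pair_freq.get(pair, 0)
        if -neg_cached == true_freq:
            return pair, true_freq
        if true_freq > min_freq:
            heapq.heappush(heap, (-true_freq, pair))
    return None, 0

# After pop_best returns a pair, all of its occurrences are merged directly using
# pair_positions[pair] (point 1), and pair_freq is updated as the merges happen.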

Here is my Python implementation (use the BPETrainer.train_from_file function to learn a subword vocabulary with subword frequencies):

from array import array
from collections import defaultdict
from itertools import chain
from copy import deepcopy
import heapq
from tqdm import tqdm
from typing import Union, List, Dict

try:
    from itertools import pairwise
    print("Using itertools.pairwise")
except ImportError:
    print("Using custom pairwise")
    from itertools import tee
    def pairwise(iterable):
        """s -> (s0,s1), (s1,s2), (s2, s3), ..."""
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)


def preproc_idx(full_indices, word_a, word_b):
    # Select the indices that can actually be merged.
    if word_a == word_b:  # Same word: drop overlapping occurrences, scanning from back to front.
        # For example, when merging (0, 0) in "1 0 0 0", scanning from back to front gives "1 0 00".
        len_a = len(word_a)
        indices = full_indices[::-1]
        new_indices = [indices[0]]
        for idx in indices[1:]:
            if new_indices[-1] - idx != len_a:
                new_indices.append(idx)
        indices = array('I', new_indices[::-1])
        return indices
    else:
        return full_indices


class BPETrainer:
    def __init__(self, vocab_size: int, min_freq: int = 10, compress_threshold: float = 0, single_char: bool = True) -> None:
        self.vocab_size = vocab_size
        self.min_freq = min_freq
        self.compress_threshold = compress_threshold
        self.corpus = self.word_pair_pos = self.word_pair_len = None
        self.seg_status = self.vocab = self.pair_freq_queue = self.word_count = None
        self.single_char = single_char

    def train_from_file(self, path: Union[List[str], str], verbose: bool = False) -> Dict[str, int]:
        self.init_word_pair(self.load_file(path, verbose), self.min_freq, verbose)

        self.epoch = 0
        while len(self.vocab) < self.vocab_size:
            comb, freq = self.most_frequent_combination()
            if freq <= self.min_freq:
                break
            self.merge_word(comb, freq)
            if verbose and (freq > int(1e5) or self.epoch % 50 == 0):
                self.log(self.epoch, comb, freq)
            word_comb = "".join(comb)
            self.vocab[word_comb] = self.word_count[word_comb] = freq
            self.epoch += 1
        if verbose:
            print(f"Final Vocab ({'' if self.single_char else 'Not'} Including Single Characters) Size: {len(self.vocab)} in {self.epoch} Epoch")
        return BPE(self.vocab)

    @classmethod
    def load_file(cls, path: Union[List[str], str], verbose: bool):
        path = path if isinstance(path, list) else [path]
        pbar = tqdm(path) if verbose else path
        for p in pbar:
            with open(p, "r") as f:
                yield f.read().replace("\n", "#")
    
    def init_word_pair(self, corpus_iter, min_freq: int, verbose: bool):
        # '#' marks line and document boundaries; pairs never cross it.
        corpus = "#".join(chain([''], corpus_iter, ['']))
        assert len(corpus) < 4294967295  # positions must fit into an unsigned 32-bit ('I') array
        word_pair_pos = defaultdict(lambda: array('I'))
        vocab = defaultdict(int)
        if verbose:
            print("Initing pair count...")
        pbar = enumerate(pairwise(corpus), start=0)

        pbar = pbar if not verbose else tqdm(pbar, total=len(corpus)-1)
        if self.single_char:
            for i, (pre_char, nxt_char) in pbar:
                if pre_char == "#":
                    continue
                vocab[pre_char] += 1
                if nxt_char == "#":
                    continue
                word_pair_pos[(pre_char, nxt_char)].append(i)
        else:
            for i, (pre_char, nxt_char) in pbar:
                if pre_char == "#" or nxt_char == "#":
                    continue
                word_pair_pos[(pre_char, nxt_char)].append(i)
        
        self.word_pair_pos = {
                word_pair: indices
                for word_pair, indices in word_pair_pos.items()
                if len(indices) >= min_freq
            }
        self.word_pair_len = {
            word_pair: len(indices)
            for word_pair, indices in self.word_pair_pos.items()
        }
        self.pair_freq_queue = [(-freq, word_pair) for word_pair, freq in self.word_pair_len.items()]
        heapq.heapify(self.pair_freq_queue)
        self.vocab = vocab
        self.corpus = corpus
        self.word_count = deepcopy(self.vocab)
        self.seg_status = array('B', [1] * (len(self.corpus) + 2))  # one uint8 cell per character (plus padding); all 1s: every character starts as its own subword
        if verbose:
            print("init finish!")

    def log(self, epoch, comb, freq):
        print(f"epoch: {epoch}\tcomb: {' + '.join(comb)}\tfreq: {freq}")

    
    def most_frequent_combination(self):
        # Lazy deletion: re-push a pair whenever its cached frequency in the queue is stale.
        while len(self.pair_freq_queue) > 0:
            cached_freq, comb = heapq.heappop(self.pair_freq_queue)
            ground_freq = self.word_pair_len[comb]
            cached_freq = -cached_freq
            if cached_freq == ground_freq:
                return comb, cached_freq
            elif ground_freq > self.min_freq:
                heapq.heappush(self.pair_freq_queue, (-ground_freq, comb))
                if ground_freq / len(self.word_pair_pos[comb]) < self.compress_threshold:
                    self.word_pair_pos[comb] = self.compress_indices(*comb) 
            else:
                self.word_pair_len.pop(comb, None)
                self.word_pair_pos.pop(comb, None)
        return None, 0
    
    def compress_indices(self, word_a, word_b, indices=None):
        # Keep only the positions where word_a and word_b are still adjacent subwords.
        seg = self.seg_status
        len_a, len_b = len(word_a), len(word_b)
        indices = indices or self.word_pair_pos[(word_a, word_b)]
        return array('I', (
            i for i in indices
            if seg[i] == len_a and seg[i + len_a] == len_b
        ))


    def merge_word(self, comb, freq):
        word_a, word_b = comb  
        word_comb = word_a + word_b
        len_a, len_b, len_comb = len(word_a), len(word_b), len(word_comb)
        seg_status = self.seg_status
        word_pair_v2 = self.word_pair_pos
        corpus = self.corpus
        indices = word_pair_v2[comb]
        if len(indices) > freq:
            indices = self.compress_indices(word_a, word_b, indices) 
        indices = preproc_idx(indices, word_a, word_b)

        self.word_count[word_a] -= len(indices)
        self.word_count[word_b] -= len(indices)
        new_pairs = defaultdict(list)
        for i in indices:
            # using seg_status to find the previous and next word
            pre_end, nxt_start = i, i + len_comb
            nxt_end = seg_status[nxt_start] + nxt_start
            pre_start = i - seg_status[i - 1]
            pre_word, nxt_word = corpus[pre_start: pre_end], corpus[nxt_start: nxt_end]
            
            if pre_word != "#":
                try:
                    self.word_pair_len[pre_word, word_a] -= 1
                except KeyError:
                    pass

                if pre_word == word_b and self.get_pre_word(pre_start) == word_a:
                    # the previous two subwords are exactly word_a and word_b
                    # For example, "1 0 1 0" -> "10 10"
                    new_pairs[word_comb, word_comb].append(pre_start - len_a)
                else:
                    new_pairs[pre_word, word_comb].append(pre_start)
        
            if nxt_word != "#":
                if not (nxt_word == word_a and self.get_nxt_word(nxt_end) == word_b):
                    # the next two subwords are not exactly word_a and word_b
                    try:
                        self.word_pair_len[word_b, nxt_word] -= 1
                    except KeyError:
                        pass
                    new_pairs[word_comb, nxt_word].append(i)
        self.update(new_pairs)

        # Update seg_status: the first and last cells of the merged subword now store
        # its length; the old start cell of word_b becomes an interior cell.
        if len_b == 1:
            for i in indices:
                seg_status[i] = len_comb
                seg_status[i + len_a] = len_comb   # old start of word_b is also the new last cell
        else:
            for i in indices:
                seg_status[i] = len_comb
                seg_status[i + len_a] = 0          # old start of word_b is now interior
                seg_status[i + len_comb - 1] = len_comb

        word_pair_v2.pop(comb)
        self.word_pair_len.pop(comb)
       
        if self.word_count[word_a] <= 0:
            p = self.vocab.pop(word_a, None)
            if p:
                print(f"\tremove {word_a}")
        if word_b != word_a and self.word_count[word_b] <= 0:
            p = self.vocab.pop(word_b, None)
            if p:
                print(f"\tremove {word_b}")
        

    def update(self, new_pairs):
        for k, v in new_pairs.items():
            if len(v) >= self.min_freq:
                data = self.word_pair_pos[k] = array("I", v)
                freq = self.word_pair_len[k] = len(data)
                heapq.heappush(self.pair_freq_queue, (-freq, k))

    def get_pre_word(self, init_pos):
        pos = init_pos - self.seg_status[init_pos - 1] 
        return self.corpus[pos: init_pos]

    def get_nxt_word(self, init_pos):
        pos = init_pos + self.seg_status[init_pos]
        return self.corpus[init_pos: pos]
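
A hypothetical usage example (the corpus path and parameter values are placeholders; train_from_file returns the learned subword-to-frequency dictionary):

trainer = BPETrainer(vocab_size=10000, min_freq=10, compress_threshold=0.5)
vocab = trainer.train_from_file("corpus.txt", verbose=True)   # "corpus.txt" is a placeholder path
print(f"Learned {len(vocab)} subwords")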
