Comments (17)

dimidd avatar dimidd commented on June 28, 2024 3

Thanks, for my use case (serving a model as an API), a context manager doesn't fit, since I need to call predict after an external event (e.g. an HTTP request), so I'm just calling _cached_inference directly.
Anyhow, I think we can finally close this issue. Thanks a lot for your great work!
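
A minimal sketch of one way to do this without calling the private _cached_inference: enter the cached_predict() context manager (shown further down in this thread) once at process startup and keep it open for the lifetime of the web app. The Flask wiring, the model path, and the request format here are illustrative assumptions, not part of finetune; it assumes a model previously saved with model.save().

import contextlib

from flask import Flask, jsonify, request
from finetune import Classifier

app = Flask(__name__)
model = Classifier.load("model.jl")  # assumed: a classifier saved earlier with model.save("model.jl")

# Enter cached_predict() once so the prediction graph stays cached across requests;
# stack.close() would release it at shutdown.
stack = contextlib.ExitStack()
stack.enter_context(model.cached_predict())

@app.route("/predict", methods=["POST"])
def predict():
    # Assumed request body: {"texts": ["some document", ...]}
    texts = request.get_json()["texts"]
    return jsonify(predictions=[str(p) for p in model.predict(texts)])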

madisonmay avatar madisonmay commented on June 28, 2024 1

This is imperfect, but here is some WIP code that might be helpful for you to use as a starting point.

def _data_generator(self):
    # Feed queued examples to the estimator one at a time; keeping this generator
    # open between calls is what prevents the predict graph from being rebuilt.
    while not self._closed:
        yield self._data.pop(0)

def _inference(self, Xs, mode=None):
    self._data = Xs
    n = len(Xs)

    if self.__class__ == SequenceLabeler:
        self._data = [[x] for x in self._data]

    if not getattr(self, 'estimator', None):
        # First call: build the estimator and open the predict generator once;
        # subsequent calls reuse both.
        self.estimator = self.get_estimator()
        self._closed = False
        dataset = lambda: self.input_pipeline._dataset_without_targets(self._data_generator, train=None).batch(1)
        self.predictions = self.estimator.predict(input_fn=dataset, yield_single_examples=True)

    _predictions = []
    for _ in range(n):
        try:
            y = next(self.predictions)
        except StopIteration as e:
            # The open generator should yield one result per queued example;
            # running out early indicates a bug in the feeding logic.
            raise e
        y = y[mode] if mode else y
        _predictions.append(y)
    return _predictions

madisonmay avatar madisonmay commented on June 28, 2024 1

I think we've found a way to have our cake and eat it too without complicating the user interface. Just padding out the final batches should allow us to get a batch speedup but not have to recompile the predict function. PR in progress at #193
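
Roughly, the padding idea looks like this (an illustrative sketch, not the actual implementation in #193): pad the final batch up to the fixed batch size with dummy examples so the compiled predict function always sees the same batch shape, then discard the predictions that correspond to the padding.

def pad_final_batch(examples, batch_size, pad_value=""):
    # Yield fixed-size batches plus a count of how many items in each are real;
    # the last batch is topped up with a dummy value so every batch has the same shape.
    for start in range(0, len(examples), batch_size):
        batch = examples[start:start + batch_size]
        n_real = len(batch)
        yield batch + [pad_value] * (batch_size - n_real), n_real

# The caller then drops the padded tail, e.g.:
# results = []
# for batch, n_real in pad_final_batch(test_data, batch_size=8):
#     results.extend(model.predict(batch)[:n_real])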

dimidd avatar dimidd commented on June 28, 2024 1

Hola Guillermo,

I'm getting sub-second end-to-end times (as measured from the web interface) using Flask.
See here for details: #153

madisonmay avatar madisonmay commented on June 28, 2024 1

Hi @dimidd,

Thanks for checking back in! I was hoping to end up with a solution where we could have our metaphorical cake and eat it too, but we ran into some limitations with how TensorFlow cleans up memory, which meant we had to opt for a more explicit interface for prediction if you want to avoid rebuilding the graph: https://finetune.indico.io/#prediction

from finetune import Classifier

model = Classifier()
model.fit(train_data, train_labels)
with model.cached_predict():
    model.predict(test_data)  # triggers prediction graph construction
    model.predict(test_data)  # graph is already cached, so subsequent calls are faster

Let me know if this solution works for you!

dimidd avatar dimidd commented on June 28, 2024

This could be a possible solution
https://raw.githubusercontent.com/marcsto/rl/master/src/fast_predict2.py
Details:
tensorflow/tensorflow#4648
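
For context, the core trick in that file (sketched loosely below; the class and the input_fn plumbing are illustrative, not the linked code verbatim) is to open estimator.predict() once over a long-lived feeding generator and then push new examples into it, so the graph is only built on the first call:

class FastPredictSketch:
    def __init__(self, estimator, input_fn_builder):
        # input_fn_builder is assumed to turn a Python generator into an Estimator input_fn.
        self.estimator = estimator
        self.input_fn_builder = input_fn_builder
        self.features = None
        self.predictions = None
        self.closed = False

    def _feed(self):
        # Stays open between calls to predict(), yielding whatever was queued last.
        while not self.closed:
            yield self.features

    def predict(self, features):
        self.features = features
        if self.predictions is None:
            # First call: builds the graph and opens the predict generator.
            self.predictions = self.estimator.predict(
                input_fn=self.input_fn_builder(self._feed))
        # Later calls just pull results from the already-open generator.
        return [next(self.predictions) for _ in range(len(features))]

    def close(self):
        self.closed = True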

madisonmay avatar madisonmay commented on June 28, 2024

Yes! It's the fact that the tf Estimator API rebuilds the graph on every call to predict that's the problem. There's some tricky logic around making sure you can still batch properly if you keep a generator open, but this is absolutely the right way to go.

dimidd avatar dimidd commented on June 28, 2024

Thanks! I've tried to import SequenceLabeler in base.py, but this caused a strange error:

~/anaconda3/envs/tensorflow_p36_new_ft/lib/python3.6/site-packages/finetune/base.py in <module>()
     30 from finetune.download import download_data_if_required
     31 from finetune.estimator_utils import PatchedParameterServerStrategy
---> 32 from finetune.sequence_labeling import SequenceLabeler
     33
     34 JL_BASE = os.path.join(os.path.dirname(__file__), "model", "Base_model.jl")

~/anaconda3/envs/tensorflow_p36_new_ft/lib/python3.6/site-packages/finetune/sequence_labeling.py in <module>()
      7 import numpy as np
      8
----> 9 from finetune.base import BaseModel, PredictMode
     10 from finetune.target_encoders import SequenceLabelingEncoder, SequenceMultiLabelingEncoder
     11 from finetune.network_modules import sequence_labeler

ImportError: cannot import name 'BaseModel'

madisonmay avatar madisonmay commented on June 28, 2024

@dimidd you have a circular import. For now you can probably just delete the SequenceLabeler-specific code and have this work for the other model types. Alternatively, you could change the check to

if self.__class__.__name__ == "SequenceLabeler":

or similar.
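
Concretely, the SequenceLabeler branch in the WIP _inference above could become something like the following sketch, which avoids importing SequenceLabeler into base.py entirely:

    # Compare against the class name rather than the class object, so base.py
    # never needs to import SequenceLabeler (which would be circular).
    if self.__class__.__name__ == "SequenceLabeler":
        self._data = [[x] for x in self._data]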

madisonmay avatar madisonmay commented on June 28, 2024

I think the long-term solution is to refactor things so we don't have to override the _inference method in SequenceLabeler. But thanks for taking a look at this issue!

dimidd avatar dimidd commented on June 28, 2024

Great! Thank you! It's very fast now. I'll close the issue when this is merged.

madisonmay avatar madisonmay commented on June 28, 2024

@dimidd The tricky part is that the old behavior still makes sense in certain scenarios because it's batched. So if you need to predict over a large amount of data in a single call to predict, that will still be faster, because the lazy generator solution uses a batch size of 1. I think there's a way to have our cake and eat it too (use a generator but have a dynamic batch size, or yield batches instead of single examples), but I'm not sure on the details. LMK if you find a good solution -- if we can get around that problem I fully support doing this by default.
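
The "yield batches instead of single examples" variant could be sketched like this (illustrative only, not finetune's implementation): the feeding generator hands out whole lists of examples, so a large predict() call still gets a real batch size while the graph stays cached, and the padding idea above handles the ragged final batch.

def _batched_data_generator(self, batch_size=32):
    # Variant of _data_generator above: yield whole lists of examples rather than
    # single ones; the input pipeline would then treat each yielded list as one
    # batch instead of applying .batch(1).
    while not self._closed:
        batch, self._data = self._data[:batch_size], self._data[batch_size:]
        yield batch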

dimidd avatar dimidd commented on June 28, 2024

Hi Madison,

IMHO we shouldn't try to square the circle, but rather let the user decide. The user usually knows in advance whether the batch version or the online one is needed. The downside is having two versions of each method, e.g. predict_batch and predict_online.

Guillermogsjc avatar Guillermogsjc commented on June 28, 2024

dimidd avatar dimidd commented on June 28, 2024

Using predict on 0.5.12, I get much better results than with 0.5.11, around 2 seconds per prediction, but it's still not instantaneous like calling _inference2 per Madison's initial suggestion. I can provide more exact stats next week if you'd like to perfect it.

dimidd avatar dimidd commented on June 28, 2024

Ah, sorry, I should have guessed the docs would get updated. I'll try next week.

madisonmay avatar madisonmay commented on June 28, 2024

No worries, should have updated this thread!
