
Multi-task BO with Service API (ax, closed)

sgbaird commented on September 1, 2024
Multi-task BO with Service API

Comments (6)

danielcohenlive commented on September 1, 2024

Great question @sgbaird! This is something we do internally in the Service API with batch trials. We plan to open-source our AxBatchClient, but unfortunately it isn't out yet.

With batch trials, you'd have a GS (generation strategy) consisting of:

  1. SOBOL
  2. GPEI or BOTORCH_MODULAR without fixed_features and status_quo_features
     • At this point there's only one trial, so you can't do multi-task yet.
  3. ST_MTGP or BOTORCH_MODULAR with fixed_features and status_quo_features

The fixed_features and status_quo_features point to a trial index, so you'd want both of them to point to the target trial, probably the most recent one. I'm not aware of any way to group non-batch trials into tasks.
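The step ordering above can be sketched in plain Python to make the control flow explicit. This is only an illustration, not Ax's AxBatchClient or GenerationStrategy: `StepChoice` and `choose_step` are hypothetical helpers, the model names are plain string labels (BOTORCH_MODULAR could stand in for GPEI or ST_MTGP, as noted above), and it assumes you pass in the indices of batch trials that already have data. It mirrors the two points from the comment: each batch trial acts as its own task, so multi-task modeling needs at least two trials, and the fixed features should target the most recent trial.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class StepChoice:
    model: str  # label of the generation step, e.g. "SOBOL" (just a string here)
    target_trial_index: Optional[int]  # trial the task features should point to, if any


def choose_step(completed_trial_indices: List[int]) -> StepChoice:
    """Hypothetical helper: pick the step for the next batch trial.

    Each batch trial is treated as its own task, so a multi-task model
    needs data from at least two trials before it has anything to relate.
    """
    n = len(completed_trial_indices)
    if n == 0:
        # No data yet: quasi-random Sobol initialization.
        return StepChoice(model="SOBOL", target_trial_index=None)
    if n == 1:
        # Only one trial (one task) so far: single-task model, no fixed features.
        return StepChoice(model="GPEI", target_trial_index=None)
    # Two or more trials: multi-task model targeting the most recent trial.
    return StepChoice(model="ST_MTGP", target_trial_index=max(completed_trial_indices))


for history in ([], [0], [0, 1], [0, 1, 2]):
    print(history, choose_step(history))
```

Using `max()` rather than the last list element keeps the "most recent trial" choice correct even if the indices arrive unordered.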

What are you trying to do? I noticed Honegumi in the prompt. Is this for the Honegumi interface or a real-world use case? Is it intentional or accidental that this use case has non-batch trials?

saitcakmak commented on September 1, 2024

Hi @sgbaird. Multi-task BO can take many forms depending on what you're trying to achieve. In any case, you need to give Ax a way to identify which task each trial belongs to. One way to do this is to add a task parameter to your search space. When generating new trials, you can then also specify which task they should be generated for. Here's an example using AxClient.

```python
import numpy as np
from ax.core.observation import ObservationFeatures
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
from ax.modelbridge.transforms.unit_x import UnitX
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.typeutils import not_none
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import init_notebook_plotting, render

init_notebook_plotting()


# Update as needed. See ax/modelbridge/registry for the default list of transforms.
transforms = [
    TaskChoiceToIntTaskChoice,  # Since we're using a string-valued task parameter.
    UnitX,
]

# Custom generation strategy to support the multi-task search space.
generation_strategy = GenerationStrategy(
    name="MultiTaskMBM",
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=5,
            model_kwargs={"deduplicate": True, "transforms": transforms},
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={"transforms": transforms},
        ),
    ],
)

ax_client = AxClient(generation_strategy=generation_strategy)

ax_client.create_experiment(
    name="hartmann_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x4",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x5",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x6",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        # Add the task parameter!
        {
            "name": "task",
            "type": "choice",
            "values": ["base", "shifted"],
            "is_task": True,
            "target_value": "base",
        },
    ],
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
)


# Evaluation produces different results based on the task value.
def evaluate(parameterization):
    x = np.array([parameterization.get(f"x{i+1}") for i in range(6)])
    value = hartmann6(x)
    if parameterization.get("task") == "shifted":
        value += 100
    # In our case, the standard error is 0, since we are computing a synthetic function.
    return {"hartmann6": (value, 0.0)}


for i in range(10):
    trial = ax_client.experiment.new_trial(
        generator_run=ax_client.generation_strategy.gen(
            experiment=ax_client.experiment,
            n=1,
            pending_observations=ax_client._get_pending_observation_features(
                experiment=ax_client.experiment
            ),
            # Need to specify which task to generate from. Alternate between the two
            # here (even i -> "shifted", odd i -> "base").
            fixed_features=ObservationFeatures(
                {"task": "base" if i % 2 else "shifted"}
            ),
        )
    )
    # No runner is attached, so mark the trial as running manually before completing it.
    trial.mark_running(no_runner_required=True)
    parameterization, trial_index = not_none(trial.arm).parameters, trial.index
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate(parameterization)
    )

# We can verify that the model is a ModelListGP of MultiTaskGP.
mb = ax_client.generation_strategy.model
mb.model.surrogate.model
```
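To sanity-check the task bookkeeping in the loop above without running Ax, here is a pure-Python sketch that mirrors its task schedule and `evaluate()`. Assumptions: a toy quadratic (`toy_objective`, hypothetical) stands in for `hartmann6`, and the x values are drawn uniformly at random rather than by Sobol or BoTorch. It shows that `"base" if i % 2 else "shifted"` sends the first trial (i = 0) to the "shifted" task and splits 10 trials 5/5, and that for the same x the two tasks differ by exactly 100. Grouping results by the `task` parameter is also the same information Ax uses to tell which task each trial belongs to.

```python
import random
from collections import defaultdict
from typing import Dict, List


def toy_objective(x: List[float]) -> float:
    """Stand-in for hartmann6 (assumption): a simple bowl centered at 0.5."""
    return sum((xi - 0.5) ** 2 for xi in x)


def evaluate(parameterization: dict) -> dict:
    """Mirrors evaluate() above: the "shifted" task adds a constant offset of 100."""
    x = [parameterization[f"x{i + 1}"] for i in range(6)]
    value = toy_objective(x)
    if parameterization["task"] == "shifted":
        value += 100
    return {"hartmann6": (value, 0.0)}


def task_for_iteration(i: int) -> str:
    # Same expression as the loop above: even i -> "shifted", odd i -> "base".
    return "base" if i % 2 else "shifted"


rng = random.Random(0)  # seeded so the sketch is reproducible
by_task: Dict[str, List[float]] = defaultdict(list)
for i in range(10):
    params = {f"x{j + 1}": rng.random() for j in range(6)}
    params["task"] = task_for_iteration(i)
    mean, _sem = evaluate(params)["hartmann6"]
    by_task[params["task"]].append(mean)

for task, values in sorted(by_task.items()):
    print(task, len(values), min(values))
```

Because the offset is a constant 100, a multi-task model that learns the correlation between tasks should be able to transfer information from "shifted" trials to the "base" target task.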

sgbaird commented on September 1, 2024

Thank you for this! I was able to run it and plan to do some additional testing.

sgbaird commented on September 1, 2024

> What are you trying to do? I noticed Honegumi in the prompt. Is this for the Honegumi interface or a real-world use case? Is it intentional or accidental that this use case has non-batch trials?

Sorry this took so long to get back to you! I lost track of this. Yes, the idea was for Honegumi and a BO tutorial on multi-task optimization, since there are a lot of chemistry and materials science use cases like this. I was trying to keep it to the simplest case, so I hadn't considered batch trials. If there's no easy way to support batch trials, I'd cross out the batch option in Honegumi's synchrony row whenever multi-task is set to True.
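The compatibility rule described above (cross out the batch option in the synchrony row when multi-task is selected) could be expressed as a small validation check. This is a hypothetical sketch: `Selections`, `disabled_synchrony_options`, `validate`, and the option names "single"/"batch" are illustrative assumptions, not Honegumi's actual schema or code.

```python
from dataclasses import dataclass
from typing import Set


@dataclass(frozen=True)
class Selections:
    """Hypothetical subset of Honegumi-style options (not Honegumi's real schema)."""
    multi_task: bool
    synchrony: str  # assumed option names: "single" or "batch"


def disabled_synchrony_options(sel: Selections) -> Set[str]:
    # Multi-task is only supported for non-batch trials here, so cross out "batch".
    return {"batch"} if sel.multi_task else set()


def validate(sel: Selections) -> None:
    """Reject combinations that the crossed-out rule forbids."""
    if sel.synchrony in disabled_synchrony_options(sel):
        raise ValueError(f"synchrony={sel.synchrony!r} is not available with multi-task")


validate(Selections(multi_task=True, synchrony="single"))  # allowed
validate(Selections(multi_task=False, synchrony="batch"))  # allowed
```

Keeping the rule in one function means the UI (which options to grey out) and the validator (which combinations to reject) cannot drift apart.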

lena-kashtelyan commented on September 1, 2024

@sgbaird is this still open or resolved? : )

sgbaird commented on September 1, 2024

Will consider solved for now, and post back/reopen if I run into any issues! Thanks all for the help! 🙂
