
ADD kotoba-whisper-v1.0 · ailia-models · OPEN

kyakuno commented on June 11, 2024
ADD kotoba-whisper-v1.0


Comments (4)

kyakuno commented on June 11, 2024

@ooe1123 Once the other models are done, it would be great if you could take this one on.


ooe1123 commented on June 11, 2024

〇 transformers/models/whisper/modeling_whisper.py

Original implementation:

class WhisperSdpaAttention(WhisperAttention):
    ...
    def forward(
        self,
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            ...
        elif is_cross_attention:
            ...
        elif past_key_value is not None:
            ...
        else:
            ...

        ...
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
        )

class WhisperDecoder(WhisperPreTrainedModel): 
    ...
    def forward(
        ...
    ):
        ...
        if self._use_flash_attention_2:
            ...
        elif self._use_sdpa and head_mask is None and not output_attentions:
            # output_attentions=True & head_mask can not be supported when using SDPA.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
        else:
            ...

Modified version (for ONNX export):

class WhisperSdpaAttention(WhisperAttention):
    ...
    def forward(
        self,
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        if is_cross_attention:
            # Concatenate the cached encoder key/value with freshly projected
            # ones, then truncate to the encoder's fixed 1500 output frames so
            # the cross-attention cache keeps a static length in the graph.
            key_states = torch.cat([past_key_value[0], self._shape(self.k_proj(key_value_states), -1, bsz)], dim=2)
            value_states = torch.cat([past_key_value[1], self._shape(self.v_proj(key_value_states), -1, bsz)], dim=2)
            key_states = key_states[:, :, :1500, :]
            value_states = value_states[:, :, :1500, :]
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        ...
        if torch.onnx.is_in_onnx_export():
            if self.is_causal:
                # ONNX tracing freezes Python conditionals, so compute both the
                # non-causal and the causal SDPA output and pick one at runtime
                # based on the target length.
                attn_output_1 = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=False,
                )
                attn_output_2 = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=True,
                )
                # Causal masking is only needed when tgt_len > 1; with a single
                # query token the plain (index 0) output is selected.
                ind = torch.gt(tgt_len, 1).type(torch.int64)
                sel = torch.stack([attn_output_1, attn_output_2])
                attn_output = sel[ind]
            else:
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=False,
                )
        else:
            # Original implementation (non-export path)
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.dropout if self.training else 0.0,
                # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
                is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
            )

class WhisperDecoder(WhisperPreTrainedModel): 
    ...
    def forward(
        ...
    ):
        ...
        if self._use_flash_attention_2:
            ...
        elif self._use_sdpa and head_mask is None and not output_attentions:
            # output_attentions=True & head_mask can not be supported when using SDPA.
            # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            #     attention_mask, input_shape, inputs_embeds, past_key_values_length
            # )
            # Drop the 4D causal mask here; causality is handled inside
            # WhisperSdpaAttention via the is_causal selection above.
            attention_mask = None
        else:
            ...
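
For reference, a minimal standalone sketch of the select-by-index trick used above (my illustration, not part of the patch): ONNX tracing cannot keep a data-dependent Python if, so both SDPA variants are computed and one is chosen through a traced integer index.

import torch
import torch.nn.functional as F

def sdpa_causal_or_not(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim)
    out_plain = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Index 1 (causal) when the query length is greater than 1, else index 0.
    ind = torch.gt(torch.tensor(q.shape[2]), 1).type(torch.int64)
    return torch.stack([out_plain, out_causal])[ind]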


ooe1123 commented on June 11, 2024

〇 transformers/generation/utils.py

Original implementation:

class GenerationMixin:
    ...
    def _greedy_search(
        ...
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        ...
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

Modified version (wraps the model and exports the decoder to ONNX):

class GenerationMixin:
    ...
    def _greedy_search(
        ...
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        ...

        if 1:  # export helper: wrap the model so the KV cache flattens to plain tensors
            class Net(nn.Module):
                def __init__(self, net):
                    super(Net, self).__init__()
                    self.net = net

                def forward(
                    self,
                    decoder_input_ids,
                    encoder_hidden_states,
                    past_key_values_0_decoder_key,
                    past_key_values_0_decoder_value,
                    past_key_values_0_encoder_key,
                    past_key_values_0_encoder_value,
                    past_key_values_1_decoder_key,
                    past_key_values_1_decoder_value,
                    past_key_values_1_encoder_key,
                    past_key_values_1_encoder_value,
                ):
                    model_inputs = {
                        "decoder_input_ids": decoder_input_ids,
                        "encoder_outputs": [encoder_hidden_states],
                        "past_key_values": [
                            [
                                past_key_values_0_decoder_key,
                                past_key_values_0_decoder_value,
                                past_key_values_0_encoder_key,
                                past_key_values_0_encoder_value,
                            ],
                            [
                                past_key_values_1_decoder_key,
                                past_key_values_1_decoder_value,
                                past_key_values_1_encoder_key,
                                past_key_values_1_encoder_value,
                            ],
                        ],
                    }
                    outputs = self.net(
                        **model_inputs,
                        return_dict=True,
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                    )
                    # Originally: return outputs
                    # Return a flat tuple instead (logits + fp16 caches) so the
                    # exported ONNX graph gets simple, individually named outputs.
                    return (
                        outputs["logits"],
                        outputs["past_key_values"][0][0].type(torch.float16),
                        outputs["past_key_values"][0][1].type(torch.float16),
                        outputs["past_key_values"][0][2].type(torch.float16),
                        outputs["past_key_values"][0][3].type(torch.float16),
                        outputs["past_key_values"][1][0].type(torch.float16),
                        outputs["past_key_values"][1][1].type(torch.float16),
                        outputs["past_key_values"][1][2].type(torch.float16),
                        outputs["past_key_values"][1][3].type(torch.float16),
                    )

            model = Net(self)
        
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # Added: on the first decoding step there is no cache yet, so seed
            # empty (length-0) fp16 key/value dummies with the expected
            # (batch, 20 heads, seq_len, 64) layout.
            if model_inputs["past_key_values"] is None:
                b = model_inputs["encoder_outputs"][0].size(0)
                d = model_inputs["encoder_outputs"][0].device
                model_inputs["past_key_values"] = [
                    [
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                    ]
                ] * 2
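            # Note (added): `[...] * 2` reuses one inner list for both layers;
            # harmless here since these zero-length dummies are only read,
            # never mutated in place.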

            # Export once the decoder self-attention cache is non-empty (the
            # second step), so tracing goes through the cache-concatenation path.
            if 1 and 0 < model_inputs["past_key_values"][0][0].size(2):
                print("------>")
                from torch.autograd import Variable  # no-op wrapper in modern PyTorch; plain tensors also work
                xx = (
                    Variable(model_inputs["decoder_input_ids"]),
                    Variable(model_inputs["encoder_outputs"].last_hidden_state),
                    Variable(model_inputs["past_key_values"][0][0]),
                    Variable(model_inputs["past_key_values"][0][1]),
                    Variable(model_inputs["past_key_values"][0][2]),
                    Variable(model_inputs["past_key_values"][0][3]),
                    Variable(model_inputs["past_key_values"][1][0]),
                    Variable(model_inputs["past_key_values"][1][1]),
                    Variable(model_inputs["past_key_values"][1][2]),
                    Variable(model_inputs["past_key_values"][1][3]),
                )
                torch.onnx.export(
                    model, xx, 'decoder_model.onnx',
                    input_names=[
                        'input_ids', 'encoder_hidden_states',
                        'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value',
                        'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value',
                        'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value',
                        'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value',
                    ],
                    output_names=[
                        'logits',
                        'present.0.decoder.key', 'present.0.decoder.value',
                        'present.0.encoder.key', 'present.0.encoder.value',
                        'present.1.decoder.key', 'present.1.decoder.value',
                        'present.1.encoder.key', 'present.1.encoder.value',
                    ],
                    dynamic_axes={
                        'input_ids': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'encoder_hidden_states': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                        'logits': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'past_key_values.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'present.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.0.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                    },
                    verbose=False, opset_version=14
                )
                print("<------")
                exit(0)
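
As a quick smoke test for the exported file, something like the following can confirm the graph loads and steps once (a hedged sketch, not from the original comment; the head count 20, head dim 64, width 1280, and encoder length 1500 follow the shapes used above, the zero-length past tensors mirror the dummy caches seeded at the first decoding step, and the dtypes are assumptions that should match the export):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder_model.onnx")
b, n_head, d_head, enc_len = 1, 20, 64, 1500
feeds = {
    "input_ids": np.array([[50258]], dtype=np.int64),  # any valid token id
    "encoder_hidden_states": np.zeros((b, enc_len, n_head * d_head), dtype=np.float32),
}
for i in (0, 1):
    for kind in ("decoder", "encoder"):
        # zero-length caches, as on the first decoding step
        feeds[f"past_key_values.{i}.{kind}.key"] = np.zeros((b, n_head, 0, d_head), dtype=np.float16)
        feeds[f"past_key_values.{i}.{kind}.value"] = np.zeros((b, n_head, 0, d_head), dtype=np.float16)
logits, *present = sess.run(None, feeds)
print(logits.shape, [p.shape for p in present])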

