The assertion that fails (modelling/attention.py, line 44 in the traceback below) is:

assert (
    self.config.model.attention.heads % n_heads_groups == 0
), f"{self.config.model.attention.heads} % {n_heads_groups} != 0"
Inference runs fine with 4, 8, and 16 IPUs, but with 12 IPUs it throws the following error:
Traceback (most recent call last):
  File "run-inference.py", line 65, in <module>
    main(args.config)
  File "run-inference.py", line 55, in main
    run_inference_popxl(config, tokenizer, hf_model=hf_model, sequence_length=2048)
  File "run-inference.py", line 15, in run_inference_popxl
    pipe = LlamaPipeline(config, hf_llama_checkpoint=hf_model, tokenizer=tokenizer)
  File "/scratch/user/u.ac/Gradient-HuggingFace/llama2-chatbot/api/pipeline.py", line 96, in __init__
    session: popxl.Session = inference(config)
  File "/scratch/user/u.ac/Gradient-HuggingFace/llama2-chatbot/inference.py", line 70, in inference
    layer_facts, layer_graph = LlamaDecoderBlockTP(config).create_graph(*embeddings_graph.graph.outputs)
  File "/scratch/user/u.ac/Gradient-HuggingFace/llama2-chatbot/modelling/decoder.py", line 28, in __init__
    self.attention = LlamaSelfAttentionTP(self.config)
  File "/scratch/user/u.ac/Gradient-HuggingFace/llama2-chatbot/modelling/attention.py", line 125, in __init__
    self.heads = LlamaAttentionHeads(config=config, replica_grouping=self.replica_grouping)
  File "/scratch/user/u.ac/Gradient-HuggingFace/llama2-chatbot/modelling/attention.py", line 44, in __init__
    assert (
AssertionError: 32 % 12 != 0
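
As far as I can tell, the failing check is just a divisibility constraint: the model's 32 attention heads have to split evenly across n_heads_groups, which seems to follow the number of IPUs used for tensor parallelism. A minimal sketch of that constraint (the N_HEADS constant and the loop below are only illustrative, not the repo's code):

# Sketch of the divisibility check from the assert above.
# Assumption: n_heads_groups tracks the IPU / tensor-parallel count.
N_HEADS = 32  # from the AssertionError: "32 % 12 != 0"

for n_heads_groups in (4, 8, 12, 16):
    ok = N_HEADS % n_heads_groups == 0
    print(f"{N_HEADS} % {n_heads_groups} == 0 -> {'passes' if ok else 'fails the assert'}")

4, 8 and 16 divide 32 evenly, 12 does not, which matches the behaviour I'm seeing.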