Hi Guys,
FYI, if you change num_heads to something other than the default num_heads=1, please make sure it is a divisor of hidden_units, or alternatively set args.hidden_units = args.hidden_units * args.num_heads after parsing the arguments. Which option is right depends on what hidden_units means. In this repo, to be consistent with the tf version, we treat it as the hidden size for all heads together, and there is no checker to ensure that users pass a valid num_heads, hidden_units combination.
I decided not to include such a checker in the code, since the right behavior depends on how you understand and habitually use multi-head attention. I created this issue just to explain the situation.
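For anyone who wants a guard in their own copy, here is a minimal sketch of both options described above. The helper name resolve_hidden_units, its per_head flag, and the example values 48 and 4 are hypothetical and not part of this repo; only the argument names hidden_units and num_heads and the default num_heads=1 come from the description above.

```python
import argparse


def resolve_hidden_units(hidden_units, num_heads, per_head=False):
    """Return the total hidden size across all heads (hypothetical helper).

    per_head=False: hidden_units already covers all heads together, so
                    num_heads must divide it evenly.
    per_head=True:  hidden_units is the size of a single head, so the
                    total is hidden_units * num_heads.
    """
    if num_heads < 1:
        raise ValueError("num_heads must be a positive integer")
    if per_head:
        return hidden_units * num_heads
    if hidden_units % num_heads != 0:
        raise ValueError(
            f"hidden_units ({hidden_units}) must be divisible by "
            f"num_heads ({num_heads})"
        )
    return hidden_units


parser = argparse.ArgumentParser()
parser.add_argument("--hidden_units", type=int, required=True)
parser.add_argument("--num_heads", type=int, default=1)

# Illustrative values only; a real script would call parser.parse_args().
args = parser.parse_args(["--hidden_units", "48", "--num_heads", "4"])
args.hidden_units = resolve_hidden_units(args.hidden_units, args.num_heads)
```

Passing per_head=True corresponds to the args.hidden_units * args.num_heads option; the default corresponds to keeping hidden_units as the size for all heads together and rejecting combinations that do not divide evenly.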
Regards,
Zan