torch.nn.LSTMのややこしいパラメータ
torch.nn.LSTMのややこしいパラメータ
LSTMに渡すテンソルはどうするのが正しい?
- LSTMに渡すinputは, 3次元のテンソル 「文章の長さ × バッチサイズ × ベクトル次元数」, とQiitaにある
- StackOverflowで見ると 「バッチサイズ × 文章の長さ × ベクトル次元数」とある
- 公式ガイドで見ると、「文章の長さ × バッチサイズ × ベクトル次元数」とある
input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence.
結論
- どっちやねん、と思ったがデフォルトでは「文章の長さ × バッチサイズ × ベクトル次元数」
- `batch_first=True` に設定すると「バッチサイズ × 文章の長さ × ベクトル次元数」ということらしい
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False