Skip to content

Commit 9eb0f43

Browse files
ahhegazy authored and facebook-github-bot committed
Inference LSTM integration test (#18559)
Summary: Pull Request resolved: #18559 Adding integration test for inference LSTM Reviewed By: houseroad Differential Revision: D14656698 fbshipit-source-id: 80fb2a72be30fcb695f4471b72bf9d6e3965bf81
1 parent aa20591 commit 9eb0f43

1 file changed

Lines changed: 80 additions & 0 deletions

File tree

caffe2/python/operator_test/torch_integration_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,86 @@ def generate_proposals_ref():
153153
torch.testing.assert_allclose(rois, a)
154154
torch.testing.assert_allclose(rois_probs, b)
155155

156+
@given(
    bsz=st.integers(1, 5),
    seq_lens=st.integers(1, 6),
    emb_lens=st.integers(5, 10),
    hidden_size=st.integers(3, 7),
    num_layers=st.integers(1, 4),
    has_biases=st.booleans(),
    is_bidirectional=st.booleans(),
    batch_first=st.booleans(),
)
def test_inference_lstm(
    self,
    bsz,
    seq_lens,
    emb_lens,
    hidden_size,
    num_layers,
    has_biases,
    is_bidirectional,
    batch_first,
):
    """Check the Caffe2 InferenceLSTM op against torch.ops._caffe2.InferenceLSTM.

    Feeds identical inputs, initial states, and (torch-derived) weights
    through the Caffe2 operator (via a workspace run) and through the
    exported torch op, then asserts that the output sequence, final
    hidden state, and final cell state all match.
    """
    # Use a locally seeded RNG so the input tensor is a deterministic
    # function of the hypothesis-drawn parameters. Calling the global
    # unseeded np.random inside an @given test makes failures
    # non-reproducible and breaks hypothesis example replay/shrinking.
    rng = np.random.RandomState(0)

    num_directions = 2 if is_bidirectional else 1
    # Zero initial hidden/cell states, shaped
    # (num_layers * num_directions, batch, hidden_size) per the LSTM contract.
    hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)

    if batch_first:
        inputs = rng.randn(bsz, seq_lens, emb_lens).astype(np.float32)
    else:
        inputs = rng.randn(seq_lens, bsz, emb_lens).astype(np.float32)

    torch_lstm = torch.nn.LSTM(
        emb_lens,
        hidden_size,
        batch_first=batch_first,
        bidirectional=is_bidirectional,
        bias=has_biases,
        num_layers=num_layers,
    )

    def inference_lstm_ref():
        # Run InferenceLSTM through the Caffe2 workspace, reusing the
        # torch LSTM's flat weights so both code paths see identical
        # parameters. NOTE(review): _flat_weights is private torch API;
        # the weight ordering is assumed to match what InferenceLSTM
        # expects — confirm against the operator schema.
        input_names = ["inputs", "hidden_0", "hidden_1"]
        workspace.FeedBlob("inputs", inputs)
        workspace.FeedBlob("hidden_0", hx)
        workspace.FeedBlob("hidden_1", hx)
        for i, param in enumerate(torch_lstm._flat_weights):
            input_names.append("param_{}".format(i))
            workspace.FeedBlob("param_{}".format(i), param.detach().numpy())

        ref_op = core.CreateOperator(
            "InferenceLSTM",
            input_names,
            ["output", "hidden", "cell"],
            num_layers=num_layers,
            has_biases=has_biases,
            batch_first=batch_first,
            bidirectional=is_bidirectional,
        )
        workspace.RunOperatorOnce(ref_op)
        return (
            workspace.FetchBlob("output"),
            workspace.FetchBlob("hidden"),
            workspace.FetchBlob("cell"),
        )

    output, hidden, cell = inference_lstm_ref()
    output = torch.tensor(output)
    hidden = torch.tensor(hidden)
    cell = torch.tensor(cell)
    # Same inputs/states/weights, now as torch tensors for the torch op.
    lstm_in = [
        torch.from_numpy(inputs),
        torch.from_numpy(hx),
        torch.from_numpy(hx),
    ] + [param.detach() for param in torch_lstm._flat_weights]

    a, b, c = torch.ops._caffe2.InferenceLSTM(
        lstm_in, num_layers, has_biases, batch_first, is_bidirectional
    )
    torch.testing.assert_allclose(output, a)
    torch.testing.assert_allclose(hidden, b)
    torch.testing.assert_allclose(cell, c)
235+
156236
# Test case is using workspace.has_cuda_support and not workspace.has_gpu_support
157237
# to exclude it from HIP because tensor interop doesn't work for HIP tensors yet
158238
@unittest.skipIf(not workspace.has_cuda_support, "No cuda support")

0 commit comments

Comments
 (0)