@@ -153,6 +153,86 @@ def generate_proposals_ref():
153153 torch .testing .assert_allclose (rois , a )
154154 torch .testing .assert_allclose (rois_probs , b )
155155
156+ @given (
157+ bsz = st .integers (1 , 5 ),
158+ seq_lens = st .integers (1 , 6 ),
159+ emb_lens = st .integers (5 , 10 ),
160+ hidden_size = st .integers (3 , 7 ),
161+ num_layers = st .integers (1 , 4 ),
162+ has_biases = st .booleans (),
163+ is_bidirectional = st .booleans (),
164+ batch_first = st .booleans (),
165+ )
166+ def test_inference_lstm (
167+ self ,
168+ bsz ,
169+ seq_lens ,
170+ emb_lens ,
171+ hidden_size ,
172+ num_layers ,
173+ has_biases ,
174+ is_bidirectional ,
175+ batch_first ,
176+ ):
177+ num_directions = 2 if is_bidirectional else 1
178+ hx = np .zeros ((num_layers * num_directions , bsz , hidden_size ), dtype = np .float32 )
179+
180+ if batch_first :
181+ inputs = np .random .randn (bsz , seq_lens , emb_lens ).astype (np .float32 )
182+ else :
183+ inputs = np .random .randn (seq_lens , bsz , emb_lens ).astype (np .float32 )
184+
185+ torch_lstm = torch .nn .LSTM (
186+ emb_lens ,
187+ hidden_size ,
188+ batch_first = batch_first ,
189+ bidirectional = is_bidirectional ,
190+ bias = has_biases ,
191+ num_layers = num_layers ,
192+ )
193+
194+ def inference_lstm_ref ():
195+ input_names = ["inputs" , "hidden_0" , "hidden_1" ]
196+ workspace .FeedBlob ("inputs" , inputs )
197+ workspace .FeedBlob ("hidden_0" , hx )
198+ workspace .FeedBlob ("hidden_1" , hx )
199+ for i , param in enumerate (torch_lstm ._flat_weights ):
200+ input_names .append ("param_{}" .format (i ))
201+ workspace .FeedBlob ("param_{}" .format (i ), param .detach ().numpy ())
202+
203+ ref_op = core .CreateOperator (
204+ "InferenceLSTM" ,
205+ input_names ,
206+ ["output" , "hidden" , "cell" ],
207+ num_layers = num_layers ,
208+ has_biases = has_biases ,
209+ batch_first = batch_first ,
210+ bidirectional = is_bidirectional ,
211+ )
212+ workspace .RunOperatorOnce (ref_op )
213+ return (
214+ workspace .FetchBlob ("output" ),
215+ workspace .FetchBlob ("hidden" ),
216+ workspace .FetchBlob ("cell" )
217+ )
218+
219+ output , hidden , cell = inference_lstm_ref ()
220+ output = torch .tensor (output )
221+ hidden = torch .tensor (hidden )
222+ cell = torch .tensor (cell )
223+ lstm_in = [
224+ torch .from_numpy (inputs ),
225+ torch .from_numpy (hx ),
226+ torch .from_numpy (hx ),
227+ ] + [param .detach () for param in torch_lstm ._flat_weights ]
228+
229+ a , b , c = torch .ops ._caffe2 .InferenceLSTM (
230+ lstm_in , num_layers , has_biases , batch_first , is_bidirectional
231+ )
232+ torch .testing .assert_allclose (output , a )
233+ torch .testing .assert_allclose (hidden , b )
234+ torch .testing .assert_allclose (cell , c )
235+
156236 # Test case is using workspace.has_cuda_support and not workspace.has_gpu_support
157237 # to exclude it from HIP because tensor interop doesn't work for HIP tensors yet
158238 @unittest .skipIf (not workspace .has_cuda_support , "No cuda support" )
0 commit comments