@@ -79,7 +79,6 @@ class OperatorState {
7979 public:
8080 OperatorState (Operator *opr, const OperatorProperty *prop) {
8181 opr_ = opr;
82- fwd_init_ = bwd_init_ = false ;
8382
8483 in_data_fwd_.resize (prop->ListArguments ().size ());
8584 in_data_bwd_.resize (prop->ListArguments ().size ());
@@ -110,47 +109,39 @@ class OperatorState {
110109 const std::vector<TBlob>& inputs,
111110 const std::vector<OpReqType>& req,
112111 const std::vector<TBlob>& outputs) {
113- if (!fwd_init_) {
114- CHECK_EQ (inputs.size (), in_data_fwd_.size () + aux_data_.size ());
115- CHECK_EQ (outputs.size (), out_data_.size ());
116- // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
117- // referred by arg_data_ptr_ will be overriden
118- for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_fwd_[i] = inputs[i];
119- for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_bwd_[i] = inputs[i];
120- for (size_t i = 0 ; i < aux_data_.size (); ++i) {
121- aux_data_[i] = inputs[i + in_data_fwd_.size ()];
122- }
123- for (size_t i = 0 ; i < out_data_.size (); ++i) out_data_[i] = outputs[i];
124- fwd_init_ = true ;
112+ CHECK_EQ (inputs.size (), in_data_fwd_.size () + aux_data_.size ());
113+ CHECK_EQ (outputs.size (), out_data_.size ());
114+ // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
115+ // referred by arg_data_ptr_ will be overridden
116+ for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_fwd_[i] = inputs[i];
117+ for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_bwd_[i] = inputs[i];
118+ for (size_t i = 0 ; i < aux_data_.size (); ++i) {
119+ aux_data_[i] = inputs[i + in_data_fwd_.size ()];
125120 }
121+ for (size_t i = 0 ; i < out_data_.size (); ++i) out_data_[i] = outputs[i];
126122 opr_->Forward (ctx, in_data_fwd_, req, out_data_, aux_data_);
127123 }
128124
129125 void Backward (const OpContext &ctx,
130126 const std::vector<TBlob>& inputs,
131127 const std::vector<OpReqType>& req,
132128 const std::vector<TBlob>& outputs) {
133- if (!bwd_init_) {
134- CHECK (fwd_init_);
135- CHECK_EQ (arg_data_ptr_.size () + aux_data_.size (), inputs.size ());
136- // override tblobs pointed by arg_data_ptr_ since they might not contain
137- // initialized data during forward pass.
138- for (size_t i = 0 ; i < arg_data_ptr_.size (); ++i) {
139- *arg_data_ptr_[i] = inputs[i];
140- }
141- for (size_t i = 0 ; i < aux_data_.size (); ++i) {
142- aux_data_[i] = inputs[inputs.size () - aux_data_.size () + i];
143- }
144- CHECK_EQ (outputs.size (), in_grad_.size ());
145- for (size_t i = 0 ; i < outputs.size (); ++i) in_grad_[i] = outputs[i];
146- bwd_init_ = true ;
129+ CHECK_EQ (arg_data_ptr_.size () + aux_data_.size (), inputs.size ());
130+ // override tblobs pointed by arg_data_ptr_ since they might not contain
131+ // initialized data during forward pass.
132+ for (size_t i = 0 ; i < arg_data_ptr_.size (); ++i) {
133+ *arg_data_ptr_[i] = inputs[i];
134+ }
135+ for (size_t i = 0 ; i < aux_data_.size (); ++i) {
136+ aux_data_[i] = inputs[inputs.size () - aux_data_.size () + i];
147137 }
138+ CHECK_EQ (outputs.size (), in_grad_.size ());
139+ for (size_t i = 0 ; i < outputs.size (); ++i) in_grad_[i] = outputs[i];
148140 opr_->Backward (ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
149141 }
150142
151143 private:
152144 Operator *opr_;
153- bool fwd_init_, bwd_init_;
154145 // input data blobs for forward and backward
155146 // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
156147 // performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is
0 commit comments