@@ -67,16 +67,7 @@ struct TORCH_API StringView {
6767// Soft limit on the number of callbacks to use;
6868constexpr std::size_t kSoftLimitCallbacks = 4 ;
6969
70- // An abstract base class for various observer contexts that can be attached to
71- // the RecordFunction.
72- struct ObserverContext {
73- virtual ~ObserverContext () {}
74- protected:
75- ObserverContext () {}
76- };
77-
7870typedef c10::SmallVector<uint64_t , kSoftLimitCallbacks > CallbackHandles;
79- typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
8071typedef uint64_t RecordFunctionHandle;
8172
8273struct TORCH_API RecordFunction {
@@ -173,15 +164,6 @@ struct TORCH_API RecordFunction {
173164 // public because of anonymous "friend" class
174165 CallbackHandles sorted_active_tls_handles_;
175166 CallbackHandles sorted_active_global_handles_;
176-
177- // Stores various ObserverContext objects with event metadata for thread local
178- // callbacks.
179- ObserverContextList tls_ctx_;
180-
181- // Stores various ObserverContext objects with event metadata for global
182- // callbacks.
183- ObserverContextList global_ctx_;
184-
185167 // Whether this RecordFunction runs any callbacks
186168 bool active = false ;
187169 // / Whether any of the picked callbacks require inputs
@@ -216,8 +198,6 @@ struct TORCH_API RecordFunction {
216198 * RecordFunctionCallback represents a pair of callbacks to be used with
217199 * RecordFunction, members:
218200 * start, end - the callbacks to run when entering and exiting the scope;
219- * optionally, the start callback may return an ObserverContext which will
220- * be passed to the end callback, use appropriate constructor accordingly.
221201 * needs_inputs - whether the callbacks need the inputs passed from the observed
222202 * function/range; NOTE: passing the inputs incurs an additional overhead;
223203 * sampling_probability - if not 1.0, then the callback is probabilistically sampled
@@ -231,25 +211,12 @@ struct TORCH_API RecordFunction {
231211 */
232212class TORCH_API RecordFunctionCallback {
233213 public:
234- // This interface supports observers that require passing an ObserverContext
235- // between start and end callbacks.
236- explicit RecordFunctionCallback (
237- std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
238- std::function<void(const RecordFunction&, ObserverContext*)> end =
239- [](const RecordFunction&, ObserverContext*) {}):
240- start_ (std::move(start)),
241- end_ (std::move(end)) {
242- scopes_.fill (true );
243- }
244-
245- // This interface is for observers that do not pass an ObserverContext object
246- // between start and end callbacks.
247214 explicit RecordFunctionCallback (
248215 std::function<void (const RecordFunction&)> start,
249216 std::function<void(const RecordFunction&)> end =
250217 [](const RecordFunction&) {}):
251- start_{[start]( const RecordFunction& rf) { start (rf); return nullptr ; }} ,
252- end_{[end]( const RecordFunction& rf, ObserverContext*) { end (rf); }} {
218+ start_ (std::move(start)) ,
219+ end_ (std::move(end)) {
253220 scopes_.fill (true );
254221 }
255222
@@ -305,20 +272,20 @@ class TORCH_API RecordFunctionCallback {
305272 return scopes_[(size_t )sc];
306273 }
307274
308- inline const std::function<std::unique_ptr<ObserverContext> (const RecordFunction&)>& start () const {
275+ inline const std::function<void (const RecordFunction&)>& start () const {
309276 return start_;
310277 }
311278
312- inline const std::function<void (const RecordFunction&, ObserverContext* )>& end () const {
279+ inline const std::function<void (const RecordFunction&)>& end () const {
313280 return end_;
314281 }
315282
316283 // whether the callbacks should run in the given scope
317284 bool shouldRun (RecordScope scope) const ;
318285
319286 private:
320- std::function<std::unique_ptr<ObserverContext> (const RecordFunction&)> start_;
321- std::function<void (const RecordFunction&, ObserverContext* )> end_;
287+ std::function<void (const RecordFunction&)> start_;
288+ std::function<void (const RecordFunction&)> end_;
322289 std::function<bool (const RecordFunctionCallback&)> should_run_;
323290 bool needs_inputs_ = false ;
324291 bool needs_ids_ = false ;
0 commit comments