@@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs(
2727 """
2828 if batch_reduction is not None and batch_reduction not in ["mean" , "sum" ]:
2929 raise ValueError ('batch_reduction must be one of ["mean", "sum"] or None' )
30- if point_reduction is not None and point_reduction not in ["mean" , "sum" ]:
31- raise ValueError ('point_reduction must be one of ["mean", "sum"] or None' )
30+ if point_reduction is not None and point_reduction not in ["mean" , "sum" , "max" ]:
31+ raise ValueError (
32+ 'point_reduction must be one of ["mean", "sum", "max"] or None'
33+ )
3234 if point_reduction is None and batch_reduction is not None :
3335 raise ValueError ("Batch reduction must be None if point_reduction is None" )
3436
@@ -80,7 +82,6 @@ def _chamfer_distance_single_direction(
8082 x_normals ,
8183 y_normals ,
8284 weights ,
83- batch_reduction : Union [str , None ],
8485 point_reduction : Union [str , None ],
8586 norm : int ,
8687 abs_cosine : bool ,
@@ -103,11 +104,6 @@ def _chamfer_distance_single_direction(
103104 raise ValueError ("weights cannot be negative." )
104105 if weights .sum () == 0.0 :
105106 weights = weights .view (N , 1 )
106- if batch_reduction in ["mean" , "sum" ]:
107- return (
108- (x .sum ((1 , 2 )) * weights ).sum () * 0.0 ,
109- (x .sum ((1 , 2 )) * weights ).sum () * 0.0 ,
110- )
111107 return ((x .sum ((1 , 2 )) * weights ) * 0.0 , (x .sum ((1 , 2 )) * weights ) * 0.0 )
112108
113109 cham_norm_x = x .new_zeros (())
@@ -135,7 +131,10 @@ def _chamfer_distance_single_direction(
135131 if weights is not None :
136132 cham_norm_x *= weights .view (N , 1 )
137133
138- if point_reduction is not None :
134+ if point_reduction == "max" :
135+ assert not return_normals
136+ cham_x = cham_x .max (1 ).values # (N,)
137+ elif point_reduction is not None :
139138 # Apply point reduction
140139 cham_x = cham_x .sum (1 ) # (N,)
141140 if return_normals :
@@ -146,22 +145,34 @@ def _chamfer_distance_single_direction(
146145 if return_normals :
147146 cham_norm_x /= x_lengths_clamped
148147
149- if batch_reduction is not None :
150- # batch_reduction == "sum"
151- cham_x = cham_x .sum ()
152- if return_normals :
153- cham_norm_x = cham_norm_x .sum ()
154- if batch_reduction == "mean" :
155- div = weights .sum () if weights is not None else max (N , 1 )
156- cham_x /= div
157- if return_normals :
158- cham_norm_x /= div
159-
160148 cham_dist = cham_x
161149 cham_normals = cham_norm_x if return_normals else None
162150 return cham_dist , cham_normals
163151
164152
153+ def _apply_batch_reduction (
154+ cham_x , cham_norm_x , weights , batch_reduction : Union [str , None ]
155+ ):
156+ if batch_reduction is None :
157+ return (cham_x , cham_norm_x )
158+ # batch_reduction == "sum"
159+ N = cham_x .shape [0 ]
160+ cham_x = cham_x .sum ()
161+ if cham_norm_x is not None :
162+ cham_norm_x = cham_norm_x .sum ()
163+ if batch_reduction == "mean" :
164+ if weights is None :
165+ div = max (N , 1 )
166+ elif weights .sum () == 0.0 :
167+ div = 1
168+ else :
169+ div = weights .sum ()
170+ cham_x /= div
171+ if cham_norm_x is not None :
172+ cham_norm_x /= div
173+ return (cham_x , cham_norm_x )
174+
175+
165176def chamfer_distance (
166177 x ,
167178 y ,
@@ -197,7 +208,8 @@ def chamfer_distance(
197208 batch_reduction: Reduction operation to apply for the loss across the
198209 batch, can be one of ["mean", "sum"] or None.
199210 point_reduction: Reduction operation to apply for the loss across the
200- points, can be one of ["mean", "sum"] or None.
211+ points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the
212+ Hausdorff distance.
201213 norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
202214 single_directional: If False (default), loss comes from both the distance between
203215 each point in x and its nearest neighbor in y and each point in y and its nearest
@@ -227,6 +239,10 @@ def chamfer_distance(
227239
228240 if not ((norm == 1 ) or (norm == 2 )):
229241 raise ValueError ("Support for 1 or 2 norm." )
242+
243+ if point_reduction == "max" and (x_normals is not None or y_normals is not None ):
244+ raise ValueError ('Normals must be None if point_reduction is "max"' )
245+
230246 x , x_lengths , x_normals = _handle_pointcloud_input (x , x_lengths , x_normals )
231247 y , y_lengths , y_normals = _handle_pointcloud_input (y , y_lengths , y_normals )
232248
@@ -238,13 +254,13 @@ def chamfer_distance(
238254 x_normals ,
239255 y_normals ,
240256 weights ,
241- batch_reduction ,
242257 point_reduction ,
243258 norm ,
244259 abs_cosine ,
245260 )
246261 if single_directional :
247- return cham_x , cham_norm_x
262+ loss = cham_x
263+ loss_normals = cham_norm_x
248264 else :
249265 cham_y , cham_norm_y = _chamfer_distance_single_direction (
250266 y ,
@@ -254,17 +270,23 @@ def chamfer_distance(
254270 y_normals ,
255271 x_normals ,
256272 weights ,
257- batch_reduction ,
258273 point_reduction ,
259274 norm ,
260275 abs_cosine ,
261276 )
262- if point_reduction is not None :
263- return (
264- cham_x + cham_y ,
265- (cham_norm_x + cham_norm_y ) if cham_norm_x is not None else None ,
266- )
267- return (
268- (cham_x , cham_y ),
269- (cham_norm_x , cham_norm_y ) if cham_norm_x is not None else None ,
270- )
277+ if point_reduction == "max" :
278+ loss = torch .maximum (cham_x , cham_y )
279+ loss_normals = None
280+ elif point_reduction is not None :
281+ loss = cham_x + cham_y
282+ if cham_norm_x is not None :
283+ loss_normals = cham_norm_x + cham_norm_y
284+ else :
285+ loss_normals = None
286+ else :
287+ loss = (cham_x , cham_y )
288+ if cham_norm_x is not None :
289+ loss_normals = (cham_norm_x , cham_norm_y )
290+ else :
291+ loss_normals = None
292+ return _apply_batch_reduction (loss , loss_normals , weights , batch_reduction )
0 commit comments