@@ -51,10 +51,12 @@ class WriterState(Enum):
5151 HEADER = auto ()
5252 KV_DATA = auto ()
5353 TI_DATA = auto ()
54+ WEIGHTS = auto ()
5455
5556
5657class GGUFWriter :
5758 fout : BufferedWriter | None
59+ path : os .PathLike [str ] | str | None
5860 temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
5961 tensors : dict [str , TensorInfo ]
6062 kv_data : dict [str , GGUFValue ]
@@ -77,7 +79,8 @@ def __init__(
7779 self , path : os .PathLike [str ] | str | None , arch : str , use_temp_file : bool = False ,
7880 endianess : GGUFEndian = GGUFEndian .LITTLE ,
7981 ):
80- self .fout = open (path , "wb" ) if path is not None else None
82+ self .fout = None
83+ self .path = path
8184 self .arch = arch
8285 self .endianess = endianess
8386 self .data_alignment = GGUF_DEFAULT_ALIGNMENT
@@ -88,19 +91,29 @@ def __init__(
8891 logger .info ("gguf: This GGUF file is for {0} Endian only" .format (
8992 "Big" if self .endianess == GGUFEndian .BIG else "Little" ,
9093 ))
91- self .state = WriterState .NO_FILE if self . fout is None else WriterState . EMPTY
94+ self .state = WriterState .NO_FILE
9295
9396 self .add_architecture ()
9497
95- def write_header_to_file (self , path : os .PathLike [str ] | str | None = None ) -> None :
96- # NOTE: not checking for WriterState.NO_FILE,
97- # because writing can technically be started over from any state,
98- # as long as a new path is provided
98+ def open_output_file (self , path : os .PathLike [str ] | str | None = None ) -> None :
99+ if self .state is WriterState .EMPTY and self .fout is not None and (path is None or path == self .path ):
100+ # allow calling this multiple times as long as the path is the same
101+ return
102+ if self .state is not WriterState .NO_FILE :
103+ raise ValueError (f'Expected output file to be not yet opened, got { self .state } ' )
104+
99105 if path is not None :
106+ self .path = path
107+
108+ if self .path is not None :
100109 if self .fout is not None :
101110 self .fout .close ()
102- self .fout = open (path , "wb" )
111+ self .fout = open (self . path , "wb" )
103112 self .state = WriterState .EMPTY
113+
114+ def write_header_to_file (self , path : os .PathLike [str ] | str | None = None ) -> None :
115+ self .open_output_file (path )
116+
104117 if self .state is not WriterState .EMPTY :
105118 raise ValueError (f'Expected output file to be empty, got { self .state } ' )
106119
@@ -206,8 +219,8 @@ def add_tensor_info(
206219 self , name : str , tensor_shape : Sequence [int ], tensor_dtype : np .dtype ,
207220 tensor_nbytes : int , raw_dtype : GGMLQuantizationType | None = None ,
208221 ) -> None :
209- if self .state is not WriterState .EMPTY and self . state is not WriterState . NO_FILE :
210- raise ValueError (f'Expected output file to be empty or absent , got { self .state } ' )
222+ if self .state is not WriterState .NO_FILE :
223+ raise ValueError (f'Expected output file to be not yet opened , got { self .state } ' )
211224
212225 if name in self .tensors :
213226 raise ValueError (f'Duplicated tensor name { name !r} ' )
@@ -263,8 +276,8 @@ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None
263276 fp .write (bytes ([0 ] * pad ))
264277
265278 def write_tensor_data (self , tensor : np .ndarray [Any , Any ]) -> None :
266- if self .state is not WriterState .TI_DATA :
267- raise ValueError (f'Expected output file to contain tensor info, got { self .state } ' )
279+ if self .state is not WriterState .TI_DATA and self . state is not WriterState . WEIGHTS :
280+ raise ValueError (f'Expected output file to contain tensor info or weights , got { self .state } ' )
268281 assert self .fout is not None
269282
270283 if self .endianess == GGUFEndian .BIG :
@@ -273,6 +286,8 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
273286 tensor .tofile (self .fout )
274287 self .write_padding (self .fout , tensor .nbytes )
275288
289+ self .state = WriterState .WEIGHTS
290+
276291 def write_tensors_to_file (self , * , progress : bool = False ) -> None :
277292 self .write_ti_data_to_file ()
278293
@@ -299,14 +314,14 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
299314 bar .update (ti .nbytes )
300315 self .write_padding (self .fout , ti .nbytes )
301316 ti .tensor = None
317+ else :
318+ self .temp_file .seek (0 )
302319
303- return
304-
305- self .temp_file .seek ( 0 )
320+ shutil . copyfileobj ( self . temp_file , self . fout )
321+ self . flush ()
322+ self .temp_file .close ( )
306323
307- shutil .copyfileobj (self .temp_file , self .fout )
308- self .flush ()
309- self .temp_file .close ()
324+ self .state = WriterState .WEIGHTS
310325
311326 def flush (self ) -> None :
312327 assert self .fout is not None
0 commit comments