 #include <cassert>
 #include <cstring>
 
+// headers for POSIX mmap
+#if defined (__unix__) || defined (__APPLE__)
+#   include <sys/mman.h>
+#   include <fcntl.h>
+#   include <unistd.h>
+#endif
+
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
 
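Note: the `mmap_file` helper added further down feature-tests with `#if defined(MAP_FAILED)` instead of repeating the platform check above; `<sys/mman.h>` is what defines `MAP_FAILED`, so the POSIX path compiles exactly where this header block is active. A minimal standalone sketch of that pattern (illustrative only, not part of the patch):

    // sketch: feature-test via a macro that only the platform header provides
    #if defined (__unix__) || defined (__APPLE__)
    #   include <sys/mman.h>   // defines MAP_FAILED on POSIX systems
    #endif
    #include <cstdio>

    int main() {
    #if defined(MAP_FAILED)
        std::printf("POSIX mmap available\n");
    #else
        std::printf("no mmap on this platform\n");
    #endif
        return 0;
    }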
@@ -246,6 +253,7 @@ static bool kv_cache_init(
     struct ggml_init_params params;
     params.mem_size   = cache.buf.size();
     params.mem_buffer = cache.buf.data();
+    params.no_alloc   = false;
 
     cache.ctx = ggml_init(params);
 
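The new `no_alloc` flag controls whether ggml reserves space for tensor data inside the context's memory pool; when it is true, only tensor headers are created and the caller must point `tensor->data` at storage it owns. The KV cache always owns its buffers, so it stays `false` here. A sketch of the no-alloc mode as the rest of this patch uses it (the pool size and `my_mapped_region` are hypothetical):

    // sketch: no_alloc = true creates tensor metadata but no data buffers
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16u*1024u*1024u,   // hypothetical pool, headers only
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    t->data = my_mapped_region;  // caller supplies storage, e.g. an mmap'd file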
@@ -288,6 +296,26 @@ struct llama_context_params llama_context_default_params() {
 // model loading
 //
 
+void * mmap_file(const char * fname) {
+#if defined(MAP_FAILED)
+    // POSIX mmap
+    int fd = open(fname, O_RDONLY);
+    size_t len = lseek(fd, 0, SEEK_END);
+    void * mm_addr = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
+    if (mm_addr == MAP_FAILED) {
+        perror("mmap failed");
+        mm_addr = NULL;
+    }
+    close(fd);
+    return mm_addr;
+#else
+    // TODO: windows support
+    (void)(fname); // suppress warnings
+    return NULL;
+#endif
+}
+
+
 static bool llama_model_load(
     const std::string & fname,
     llama_context & lctx,
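As written, `mmap_file` does not check whether `open` succeeded: on failure `fd` is -1, `lseek` and `mmap` then fail in turn, and the function still returns NULL through the `MAP_FAILED` branch (with a harmless `close(-1)` on the way out). Note also that storing the `lseek` result in a `size_t` silently converts a -1 error into `SIZE_MAX`. A more defensive variant might look like this (a sketch under the same POSIX-only scope, not the patch's code):

    // sketch: mmap_file with each step checked explicitly
    void * mmap_file_checked(const char * fname) {
    #if defined(MAP_FAILED)
        int fd = open(fname, O_RDONLY);
        if (fd == -1) {
            perror("open failed");
            return NULL;
        }
        off_t len = lseek(fd, 0, SEEK_END);
        if (len == (off_t) -1) {
            perror("lseek failed");
            close(fd);
            return NULL;
        }
        void * mm_addr = mmap(NULL, (size_t) len, PROT_READ, MAP_SHARED, fd, 0);
        close(fd); // the mapping stays valid after the descriptor is closed
        if (mm_addr == MAP_FAILED) {
            perror("mmap failed");
            return NULL;
        }
        return mm_addr;
    #else
        (void) fname;
        return NULL;
    #endif
    }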
@@ -303,6 +331,7 @@ static bool llama_model_load(
 
     lctx.t_start_us = t_start_us;
 
+    // TODO: this could probably be smaller when using mmap
     std::vector<char> f_buf(1024*1024);
 
     auto & model = lctx.model;
@@ -449,39 +478,49 @@ static bool llama_model_load(
         }
     }
 
+    bool use_mmap = (n_parts == 1);
+
+    // try to memory map the model file
+    void * mm_addr = NULL;
+    if (use_mmap) {
+        mm_addr = mmap_file(fname.c_str());
+        if (mm_addr == NULL) {
+            use_mmap = false;
+        }
+    }
+
+
+
     auto & ctx = model.ctx;
 
     size_t ctx_size = 0;
-
     {
         const auto & hparams = model.hparams;
 
         const int n_embd  = hparams.n_embd;
         const int n_layer = hparams.n_layer;
-        const int n_ctx   = hparams.n_ctx;
         const int n_vocab = hparams.n_vocab;
 
-        ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
+        if (!use_mmap) {
+            ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
 
-        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
+            ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
 
-        ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
+            ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
 
-        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
+            ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
 
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
-        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
+            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
+            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
+            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
+            ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
 
-        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
+            ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
 
-        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
-        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
-        ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
-
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
-        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
+            ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
+            ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
+            ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
+        }
 
         ctx_size += (5 + 10*n_layer)*256; // object overhead
 
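With mmap in use, `ctx_size` drops from the full weight footprint to little more than the object-overhead term, because the weights stay in the page cache rather than the context pool. Rough illustrative figures, assuming LLaMA-7B-like hyperparameters (n_embd = 4096, n_layer = 32, n_vocab = 32000, F16 weights):

    tok_embeddings alone:  32000 * 4096 * 2 B             ~= 250 MiB  (no longer counted)
    object overhead:       (5 + 10*32) * 256 B = 83200 B  ~=  81 KiB  (still counted)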
@@ -514,6 +553,7 @@ static bool llama_model_load(
         struct ggml_init_params params = {
             /*.mem_size   =*/ lctx.model.buf.size(),
             /*.mem_buffer =*/ lctx.model.buf.data(),
+            /*.no_alloc   =*/ use_mmap,
         };
 
         model.ctx = ggml_init(params);
@@ -595,7 +635,7 @@ static bool llama_model_load(
             fname_part += "." + std::to_string(i);
         }
 
-        fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
+        fprintf(stderr, "%s: loading model part %d/%d from '%s'%s\n", __func__, i+1, n_parts, fname_part.c_str(), use_mmap ? " (memory mapped)" : "");
 
         fin = std::ifstream(fname_part, std::ios::binary);
         fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
@@ -736,7 +776,14 @@ static bool llama_model_load(
                 }
 
                 if (part_id == 0) {
-                    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                    if (mm_addr) {
+                        off_t offset = fin.tellg();
+                        tensor->data = (char *) mm_addr + offset;
+                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
+                    }
+                    else {
+                        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                    }
                 } else {
                     fin.seekg(ggml_nbytes(tensor), std::ios::cur);
                 }
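Nothing is copied in the mmap branch: a tensor's data pointer becomes the mapping base plus that tensor's byte offset in the file, and `seekg` merely advances past the payload. Pages fault in from disk only when a tensor is first touched, which is where the load-time win comes from; since the mapping is `PROT_READ`, any code that wrote through these tensors would crash. A toy standalone demonstration of the same aliasing idea (file name and offset are hypothetical):

    // sketch: alias into a read-only mapping instead of copying (POSIX only)
    #include <sys/mman.h>
    #include <fcntl.h>
    #include <unistd.h>
    #include <cstdio>

    int main() {
        int fd = open("model.bin", O_RDONLY);   // hypothetical model file
        off_t len = lseek(fd, 0, SEEK_END);
        void * mm = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
        close(fd);
        if (mm == MAP_FAILED) { perror("mmap"); return 1; }

        const char * payload = (const char *) mm + 128;   // hypothetical offset
        std::printf("first payload byte: %d\n", payload[0]); // page-faults here
        munmap(mm, (size_t) len);
        return 0;
    }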
@@ -849,6 +896,7 @@ static bool llama_eval_internal(
         struct ggml_init_params params = {
             /*.mem_size   =*/ buf_compute.size(),
             /*.mem_buffer =*/ buf_compute.data(),
+            /*.no_alloc   =*/ false,
         };
 
         struct ggml_context * ctx0 = ggml_init(params);
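Because `ggml_init_params` gained a member, every aggregate initializer in the file must be extended, which is why this otherwise unrelated eval-path context also picks up an explicit `/*.no_alloc =*/ false`: the compute graph's intermediate tensors need real storage. The `/*.name =*/` comments simply document which member each positional value fills, as in this sketch (field order as implied by the diff):

    // sketch: positional aggregate init; the comments name the target fields
    struct ggml_init_params {
        size_t mem_size;
        void * mem_buffer;
        bool   no_alloc;   // member introduced by this change
    };

    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };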