Skip to content

Auto and Base model classes in AutoAWQ

View the documentation of the main classes of AutoAWQ models below.

awq.models.auto.AutoAWQForCausalLM

AutoAWQForCausalLM()
Source code in awq/models/auto.py
51
52
53
54
55
def __init__(self):
    """Refuse direct construction of the auto class.

    ``AutoAWQForCausalLM`` is only a dispatcher; use one of its factory
    classmethods instead.

    Raises:
        EnvironmentError: always, naming ``from_quantized`` and
            ``from_pretrained`` as the supported entry points.
    """
    message = (
        "You must instantiate AutoAWQForCausalLM with\n"
        "AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained"
    )
    raise EnvironmentError(message)

from_pretrained classmethod

from_pretrained(model_path, trust_remote_code=True, safetensors=True, device_map='auto', download_kwargs=None, **model_init_kwargs)
PARAMETER DESCRIPTION
model_path

trust_remote_code

DEFAULT: True

safetensors

DEFAULT: True

device_map

DEFAULT: 'auto'

download_kwargs

DEFAULT: None

**model_init_kwargs

DEFAULT: {}

Source code in awq/models/auto.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@classmethod
def from_pretrained(
    self,
    model_path,
    trust_remote_code=True,
    safetensors=True,
    device_map="auto",
    download_kwargs=None,
    **model_init_kwargs,
) -> BaseAWQForCausalLM:
    """Load an unquantized model via the AWQ class registered for its type.

    The model type is detected from ``model_path`` with
    ``check_and_get_model_type``. The class found under that type in
    ``AWQ_CAUSAL_LM_MODEL_MAP`` then performs the load; every argument,
    including ``**model_init_kwargs``, is forwarded to it.

    Returns:
        BaseAWQForCausalLM: the wrapped, not-yet-quantized model.
    """
    detected_type = check_and_get_model_type(
        model_path, trust_remote_code, **model_init_kwargs
    )
    awq_cls = AWQ_CAUSAL_LM_MODEL_MAP[detected_type]

    loader_kwargs = dict(
        trust_remote_code=trust_remote_code,
        safetensors=safetensors,
        device_map=device_map,
        download_kwargs=download_kwargs,
    )
    return awq_cls.from_pretrained(
        model_path,
        detected_type,
        **loader_kwargs,
        **model_init_kwargs,
    )

from_quantized classmethod

from_quantized(quant_path, quant_filename='', max_seq_len=2048, trust_remote_code=True, fuse_layers=True, use_exllama=False, use_exllama_v2=False, use_ipex=False, batch_size=1, safetensors=True, device_map='balanced', max_memory=None, offload_folder=None, download_kwargs=None, **config_kwargs)
PARAMETER DESCRIPTION
quant_path

quant_filename

DEFAULT: ''

max_seq_len

DEFAULT: 2048

trust_remote_code

DEFAULT: True

fuse_layers

DEFAULT: True

use_exllama

DEFAULT: False

use_exllama_v2

DEFAULT: False

use_ipex

DEFAULT: False

batch_size

DEFAULT: 1

safetensors

DEFAULT: True

device_map

DEFAULT: 'balanced'

max_memory

DEFAULT: None

offload_folder

DEFAULT: None

download_kwargs

DEFAULT: None

**config_kwargs

DEFAULT: {}

Source code in awq/models/auto.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@classmethod
def from_quantized(
    self,
    quant_path,
    quant_filename="",
    max_seq_len=2048,
    trust_remote_code=True,
    fuse_layers=True,
    use_exllama=False,
    use_exllama_v2=False,
    use_ipex=False,
    batch_size=1,
    safetensors=True,
    device_map="balanced",
    max_memory=None,
    offload_folder=None,
    download_kwargs=None,
    **config_kwargs,
) -> BaseAWQForCausalLM:
    """Load an already-quantized checkpoint via the matching AWQ class.

    ``batch_size`` is not passed on as an argument. It is published through
    the ``AWQ_BATCH_SIZE`` environment variable instead.

    A legacy ``max_new_tokens`` entry in ``config_kwargs`` overrides
    ``max_seq_len`` and logs a deprecation warning. It is still forwarded
    together with the rest of ``config_kwargs``.

    Returns:
        BaseAWQForCausalLM: the wrapped quantized model.
    """
    os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
    model_type = check_and_get_model_type(quant_path, trust_remote_code)

    legacy_len = config_kwargs.get("max_new_tokens")
    if legacy_len is not None:
        max_seq_len = legacy_len
        logging.warning(
            "max_new_tokens argument is deprecated... gracefully "
            "setting max_seq_len=max_new_tokens."
        )

    backend_kwargs = dict(
        trust_remote_code=trust_remote_code,
        fuse_layers=fuse_layers,
        use_exllama=use_exllama,
        use_exllama_v2=use_exllama_v2,
        use_ipex=use_ipex,
        safetensors=safetensors,
        device_map=device_map,
        max_memory=max_memory,
        offload_folder=offload_folder,
        download_kwargs=download_kwargs,
    )
    awq_cls = AWQ_CAUSAL_LM_MODEL_MAP[model_type]
    return awq_cls.from_quantized(
        quant_path,
        model_type,
        quant_filename,
        max_seq_len,
        **backend_kwargs,
        **config_kwargs,
    )

awq.models.base.BaseAWQForCausalLM

BaseAWQForCausalLM(model, model_type, is_quantized, config, quant_config, processor)

Bases: Module

The base model for all AutoAWQ models.

PARAMETER DESCRIPTION
model

The pretrained or quantized model.

TYPE: PreTrainedModel

model_type

The model type, found in config.json.

TYPE: str

is_quantized

Indicates if the current model is quantized.

TYPE: bool

config

The config of the model.

TYPE: PretrainedConfig

quant_config

The quantization config of the model.

TYPE: AwqConfig

processor

An optional processor, e.g. for vision models.

TYPE: AutoProcessor

Source code in awq/models/base.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def __init__(
    self,
    model: Annotated[PreTrainedModel, Doc("The pretrained or quantized model.")],
    model_type: Annotated[str, Doc("The model type, found in config.json.")],
    is_quantized: Annotated[
        bool, Doc("Indicates if the current model is quantized.")
    ],
    config: Annotated[PretrainedConfig, Doc("The config of the model.")],
    quant_config: Annotated[
        AwqConfig, Doc("The quantization config of the model.")
    ],
    processor: Annotated[
        AutoProcessor, Doc("An optional processor, e.g. for vision models.")
    ],
):
    """The base model for all AutoAWQ models.

    Stores the wrapped transformers model together with its configs. No
    loading or conversion happens here. The ``from_pretrained`` and
    ``from_quantized`` classmethods build the pieces and then call this
    constructor.
    """
    super().__init__()

    # Assigned first so the wrapped nn.Module is registered as a submodule.
    self.model: PreTrainedModel = model

    # Plain metadata describing the wrapped model.
    self.model_type: str = model_type
    self.is_quantized: bool = is_quantized
    self.config: PretrainedConfig = config
    self.quant_config: AwqConfig = quant_config
    self.processor: CLIPImageProcessor = processor

    # Placeholder; starts out empty.
    self.search_result = None

model instance-attribute

model = model

model_type instance-attribute

model_type = model_type

is_quantized instance-attribute

is_quantized = is_quantized

search_result instance-attribute

search_result = None

config instance-attribute

config = config

quant_config instance-attribute

quant_config = quant_config

processor instance-attribute

processor = processor

to

to(device)

A utility function for moving the model to a device.

PARAMETER DESCRIPTION
device

The device to move your model to.

TYPE: str

Source code in awq/models/base.py
116
117
118
def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
    """Move the wrapped model to ``device``.

    Delegates to the wrapped model's ``.to`` and returns its result unchanged.
    """
    inner = self.model
    return inner.to(device)

forward

forward(*args, **kwargs)

A forward function that mimics the torch forward.

PARAMETER DESCRIPTION
*args

DEFAULT: ()

**kwargs

DEFAULT: {}

Source code in awq/models/base.py
120
121
122
def forward(self, *args, **kwargs):
    """Call the wrapped model, mirroring ``torch.nn.Module.forward``.

    All positional and keyword arguments are passed through unchanged, and
    the wrapped model's output is returned as-is.
    """
    wrapped = self.model
    return wrapped(*args, **kwargs)

generate

generate(*args, **kwargs)

A generate function that mimics the HF generate function.

PARAMETER DESCRIPTION
*args

DEFAULT: ()

**kwargs

DEFAULT: {}

Source code in awq/models/base.py
124
125
126
127
def generate(self, *args, **kwargs):
    """Run the wrapped model's ``generate`` with autograd disabled.

    Mirrors the Hugging Face ``generate`` API: every argument is forwarded.
    The call runs under ``torch.inference_mode()``.
    """
    with torch.inference_mode():
        outputs = self.model.generate(*args, **kwargs)
    return outputs

quantize

quantize(tokenizer=None, quant_config={}, calib_data='pileval', split='train', text_column='text', duo_scaling=True, export_compatible=False, apply_clip=True, n_parallel_calib_samples=None, max_calib_samples=128, max_calib_seq_len=512, max_chunk_memory=1024 * 1024 * 1024)

The main quantization function that you can use to quantize your model.

Example:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "..."
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
model.quantize(tokenizer, quant_config)
PARAMETER DESCRIPTION
tokenizer

The tokenizer to use for quantization.

TYPE: PreTrainedTokenizer DEFAULT: None

quant_config

The quantization config you want to use.

TYPE: Dict DEFAULT: {}

calib_data

The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples.

TYPE: Union[str, List[str]] DEFAULT: 'pileval'

split

The split of calib_data.

TYPE: str DEFAULT: 'train'

text_column

The text column of calib_data.

TYPE: str DEFAULT: 'text'

duo_scaling

Whether to scale using both w/x or just x.

TYPE: bool DEFAULT: True

export_compatible

This argument avoids real quantization by only applying the scales to the FP16 weights, without quantizing them.

TYPE: bool DEFAULT: False

apply_clip

Whether to apply clipping to the model during quantization. Some models may perform better with this set to False.

TYPE: bool DEFAULT: True

n_parallel_calib_samples

The number of parallel samples to run through the model. A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. If None, runs through all samples at the same time. You can set this to a low number for more memory efficient quantization.

TYPE: int DEFAULT: None

max_calib_samples

The maximum number of samples to run through the model.

TYPE: int DEFAULT: 128

max_calib_seq_len

The maximum sequence length of the calibration dataset. Samples longer than max_calib_seq_len are discarded.

TYPE: int DEFAULT: 512

max_chunk_memory

The loss computation and per-channel mean is optimized into chunked computations. Adjust this parameter to increase or decrease memory usage for these computations. Default is 1GB (1024 * 1024 * 1024).

TYPE: int DEFAULT: 1024 * 1024 * 1024

Source code in awq/models/base.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
@torch.no_grad()
def quantize(
    self,
    tokenizer: Annotated[
        PreTrainedTokenizer, Doc("The tokenizer to use for quantization.")
    ] = None,
    quant_config: Annotated[
        Dict, Doc("The quantization config you want to use.")
    ] = None,
    calib_data: Annotated[
        Union[str, List[str]],
        Doc(
            "The calibration dataset. Either a string pointing to Huggingface or a list of preloaded examples."
        ),
    ] = "pileval",
    split: Annotated[str, Doc("The split of calib_data.")] = "train",
    text_column: Annotated[str, Doc("The text column of calib_data.")] = "text",
    duo_scaling: Annotated[
        bool, Doc("Whether to scale using both w/x or just x.")
    ] = True,
    export_compatible: Annotated[
        bool,
        Doc(
            "This argument avoids real quantization by only applying the scales without quantizing down to FP16."
        ),
    ] = False,
    apply_clip: Annotated[
        bool,
        Doc(
            "Whether to apply clipping to the model during quantization. Some models may perform better with this set to False."
        ),
    ] = True,
    n_parallel_calib_samples: Annotated[
        int,
        Doc(
            "The number of parallel samples to run through the model. "
            "A high number of parallel samples can result in OOM during quantization if max_calib_samples is high enough. "
            "If None, runs through all samples at the same time. "
            "You can set this to a low number for more memory efficient quantization."
        ),
    ] = None,
    max_calib_samples: Annotated[
        int, Doc("The maximum number of samples to run through the model.")
    ] = 128,
    max_calib_seq_len: Annotated[
        int,
        Doc(
            "The maximum sequence length of the calibration dataset. Discard samples greater than max_calib_seq_len."
        ),
    ] = 512,
    max_chunk_memory: Annotated[
        int,
        Doc(
            "The loss computation and per-channel mean is optimized into chunked computations."
            " Adjust this parameter to increase or decrease memory usage for these computations."
            " Default is 1GB (1024 * 1024 * 1024)."
        ),
    ] = 1024 * 1024 * 1024,
):
    """
    The main quantization function that you can use to quantize your model.

    Example:

    ```python
    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "..."
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
    model.quantize(tokenizer, quant_config)
    ```

    When ``quant_config`` is omitted (or ``None``), an empty dict is passed to
    ``AwqConfig.from_dict``, so whatever defaults that method applies are used.
    A fresh dict is created on every call rather than sharing one mutable
    default object between calls.

    Side effects:
        - replaces ``self.quant_config`` with the parsed ``AwqConfig``;
        - stores the ``AwqQuantizer`` on ``self.quantizer`` (``pack`` uses it);
        - sets ``self.is_quantized = True`` once quantization has finished.
    """
    # Never share a mutable default between calls: build a new dict each time.
    if quant_config is None:
        quant_config = {}
    self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

    # A model-specific subclass may declare modules that must stay unquantized.
    if hasattr(self, "modules_to_not_convert"):
        self.quant_config.modules_to_not_convert = self.modules_to_not_convert

    self.quantizer = AwqQuantizer(
        self,
        self.model,
        tokenizer,
        self.quant_config.w_bit,
        self.quant_config.q_group_size,
        self.quant_config.zero_point,
        self.quant_config.version,
        calib_data,
        split,
        text_column,
        duo_scaling,
        modules_to_not_convert=self.quant_config.modules_to_not_convert,
        export_compatible=export_compatible,
        apply_clip=apply_clip,
        n_parallel_calib_samples=n_parallel_calib_samples,
        max_calib_samples=max_calib_samples,
        max_calib_seq_len=max_calib_seq_len,
        max_chunk_memory=max_chunk_memory,
    )
    self.quantizer.quantize()

    self.is_quantized = True

pack

pack()

A utility function for the following scenario. Note that save_quantized will overwrite existing weights if you use the same quant_path.

Example:

model.quantize(
    tokenizer,
    quant_config=quant_config,
    export_compatible=True
)
model.save_quantized(...)  # produces GGUF/other compat weights
model.pack()  # makes the model CUDA compat
model.save_quantized(...)  # produces CUDA compat weights
Source code in awq/models/base.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
@torch.no_grad()
def pack(self):
    """
    A utility function for the following scenario. Note that save_quantized will
    overwrite existing weights if you use the same quant_path.

    Example:

    ```python
    model.quantize(
        tokenizer,
        quant_config=quant_config,
        export_compatible=True
    )
    model.save_quantized(...)  # produces GGUF/other compat weights
    model.pack()  # makes the model CUDA compat
    model.save_quantized(...)  # produces CUDA compat weights
    ```

    Delegates to ``pack()`` on the quantizer that ``quantize`` stored on
    ``self.quantizer``.

    Raises:
        AttributeError: if ``quantize`` has not been called on this model, so
            there is no quantizer to pack with.
    """
    quantizer = getattr(self, "quantizer", None)
    if quantizer is None:
        # Fail with an actionable message instead of a bare missing-attribute error.
        raise AttributeError(
            "pack() requires a quantizer; call quantize(...) before pack()."
        )
    quantizer.pack()

fuse_layers staticmethod

fuse_layers(model)
PARAMETER DESCRIPTION
model

Source code in awq/models/base.py
257
258
259
@staticmethod
def fuse_layers(model):
    """Hook for fusing a model's layers into optimized modules.

    The base implementation is a no-op that returns ``None``. ``from_quantized``
    calls it when ``fuse_layers=True`` and the AWQ extension imports
    successfully. Model-specific subclasses presumably override it.
    """
    return None

save_quantized

save_quantized(save_dir, safetensors=True, shard_size='5GB')
PARAMETER DESCRIPTION
save_dir

The directory to save your model to.

TYPE: str

safetensors

Whether to save the model as safetensors or torch files.

TYPE: bool DEFAULT: True

shard_size

The shard size for sharding large models into multiple chunks.

TYPE: str DEFAULT: '5GB'

Source code in awq/models/base.py
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
def save_quantized(
    self,
    save_dir: Annotated[str, Doc("The directory to save your model to.")],
    safetensors: Annotated[
        bool, Doc("Whether to save the model as safetensors or torch files.")
    ] = True,
    shard_size: Annotated[
        str, Doc("The shard size for sharding large models into multiple chunks.")
    ] = "5GB",
):
    """Save the quantized model, its configs and (if present) its processor.

    The config files are written by ``save_pretrained`` with an empty state
    dict. The real weights are then written as shards of at most
    ``shard_size``, named after ``model.safetensors`` or ``pytorch_model.bin``.
    A ``<weights name>.index.json`` file is also written whenever
    ``shard_checkpoint`` returns an index.

    Raises:
        ValueError: if ``save_dir`` is an empty string.
    """
    # An empty path previously crashed with IndexError on ``save_dir[-1]``.
    if not save_dir:
        raise ValueError("save_dir must be a non-empty directory path.")
    save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

    # Stand-in module with an empty state dict, so save_pretrained writes the
    # config files without serializing any weights.
    class EmptyModule(nn.Module):
        def __init__(self):
            super(EmptyModule, self).__init__()

        def forward(self, x):
            return x

    # Save model and config files with empty state dict
    self.model.config.quantization_config = self.quant_config.to_transformers_dict()
    self.model.generation_config.do_sample = True
    self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

    # Vision models carry a processor that is saved alongside the model.
    if self.processor is not None:
        self.processor.save_pretrained(save_dir)

    # Remove the empty default weight files save_pretrained may have written;
    # the real weight shards are written below.
    default_paths = [
        f"{save_dir}/model.safetensors",
        f"{save_dir}/pytorch_model.bin",
    ]
    for path in default_paths:
        if os.path.exists(path):
            os.remove(path)

    # Full weight file name (with extension); also the shard naming template.
    model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

    # Shard the checkpoint into chunks of at most `shard_size` (5GB by default).
    # NOTE(review): transformers deprecated `shard_checkpoint` in recent
    # releases; huggingface_hub.split_torch_state_dict_into_shards is the
    # replacement if this import starts failing.
    shards, index = shard_checkpoint(
        self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
    )

    for shard_file, shard in shards.items():
        shard_path = os.path.join(save_dir, shard_file)
        if safetensors:
            # safetensors rejects tensors that share storage or are not
            # contiguous, so each tensor is copied into its own contiguous buffer.
            shard = {k: v.clone().contiguous() for k, v in shard.items()}
            save_file(shard, shard_path, metadata={"format": "pt"})
        else:
            torch.save(shard, shard_path)

    # Save the shard index (maps parameter names to shard files).
    if index is not None:
        with open(f"{save_dir}/{model_name}.index.json", "w") as file:
            json.dump(index, file, indent=4)

from_pretrained classmethod

from_pretrained(model_path, model_type, torch_dtype=torch.float16, trust_remote_code=True, safetensors=True, device_map='auto', download_kwargs=None, **model_init_kwargs)

A method for initialization of pretrained models, usually in FP16.

PARAMETER DESCRIPTION
model_path

A Huggingface path or local path to a model.

TYPE: str

model_type

The model type, loaded from config.json.

TYPE: str

torch_dtype

The dtype to load the model as. May not work with values other than float16.

TYPE: dtype DEFAULT: float16

trust_remote_code

Useful for Huggingface repositories that have not been integrated into transformers yet.

TYPE: bool DEFAULT: True

safetensors

Whether to download/load safetensors instead of torch weights.

TYPE: bool DEFAULT: True

device_map

A device map that will be passed onto the model loading method from transformers.

TYPE: Union[str, Dict] DEFAULT: 'auto'

download_kwargs

Keyword arguments used to configure how the model is downloaded.

TYPE: Dict DEFAULT: None

**model_init_kwargs

Additional kwargs that are passed to the model during initialization.

TYPE: Dict DEFAULT: {}

Source code in awq/models/base.py
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
@classmethod
def from_pretrained(
    self,
    model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
    model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
    torch_dtype: Annotated[
        torch.dtype,
        Doc(
            "The dtype to load the model as. May not work with other values than float16."
        ),
    ] = torch.float16,
    trust_remote_code: Annotated[
        bool,
        Doc(
            "Useful for Huggingface repositories that have not been integrated into transformers yet."
        ),
    ] = True,
    safetensors: Annotated[
        bool, Doc("Whether to download/load safetensors instead of torch weights.")
    ] = True,
    device_map: Annotated[
        Union[str, Dict],
        Doc(
            "A device map that will be passed onto the model loading method from transformers."
        ),
    ] = "auto",
    download_kwargs: Annotated[
        Dict,
        Doc("Used for configure download model"),
    ] = None,
    **model_init_kwargs: Annotated[
        Dict,
        Doc(
            "Additional kwargs that are passed to the model during initialization."
        ),
    ],
):
    """A method for initialization of pretrained models, usually in FP16.

    ``self`` is the class here (this is a ``classmethod``). The returned object
    is ``self(...)`` wrapping the loaded model, with ``is_quantized=False``.

    The transformers auto-class is looked up from
    ``TRANSFORMERS_AUTO_MAPPING_DICT`` by ``config.model_type``. For
    ``AutoModelForVision2Seq`` models, the ``image_processor`` of the
    ``AutoProcessor`` is loaded as well and stored as ``processor``.
    """
    # Resolve the local weights path plus the model and quantization configs.
    model_weights_path, config, quant_config = self._load_config(
        self,
        model_path,
        "",
        safetensors,
        trust_remote_code=trust_remote_code,
        download_kwargs=download_kwargs,
    )

    # Pick the transformers auto-class registered for this architecture
    # (not necessarily AutoModelForCausalLM).
    target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
    target_cls = getattr(transformers, target_cls_name)

    # Vision-to-sequence models also carry an image processor.
    is_vision = target_cls_name == "AutoModelForVision2Seq"
    processor: CLIPImageProcessor = (
        AutoProcessor.from_pretrained(model_weights_path).image_processor
        if is_vision
        else None
    )

    # Load the unquantized weights.
    model = target_cls.from_pretrained(
        model_weights_path,
        trust_remote_code=trust_remote_code,
        torch_dtype=torch_dtype,
        use_safetensors=safetensors,
        device_map=device_map,
        **model_init_kwargs,
    )
    model.eval()

    return self(
        model,
        model_type,
        is_quantized=False,
        config=config,
        quant_config=quant_config,
        processor=processor,
    )

from_quantized classmethod

from_quantized(model_path, model_type, model_filename='', max_seq_len=None, torch_dtype=torch.float16, trust_remote_code=True, safetensors=True, fuse_layers=True, use_exllama=False, use_exllama_v2=False, use_ipex=False, device_map='balanced', max_memory=None, offload_folder=None, download_kwargs=None, **config_kwargs)

A method for initialization of a quantized model, usually in INT4.

PARAMETER DESCRIPTION
model_path

A Huggingface path or local path to a model.

TYPE: str

model_type

The model type, loaded from config.json.

TYPE: str

model_filename

Load a specific model's filename by specifying this argument.

TYPE: str DEFAULT: ''

max_seq_len

The maximum cached sequence length of the model. Larger values may increase loading time and memory usage.

TYPE: int DEFAULT: None

torch_dtype

The dtype to load the model as. May not work with values other than float16.

TYPE: dtype DEFAULT: float16

trust_remote_code

Useful for Huggingface repositories that have not been integrated into transformers yet.

TYPE: bool DEFAULT: True

safetensors

Whether to download/load safetensors instead of torch weights.

TYPE: bool DEFAULT: True

fuse_layers

Whether to use fused/optimized combination of layers for increased speed.

TYPE: bool DEFAULT: True

use_exllama

Whether to map the weights to ExLlamaV1 kernels.

TYPE: bool DEFAULT: False

use_exllama_v2

Whether to map the weights to ExLlamaV2 kernels.

TYPE: bool DEFAULT: False

use_ipex

Whether to map the weights to ipex kernels for CPU device.

TYPE: bool DEFAULT: False

device_map

A device map that will be passed onto the model loading method from transformers.

TYPE: Union[str, Dict] DEFAULT: 'balanced'

max_memory

A dictionary mapping device identifiers to maximum memory, passed on to the model loading method from transformers. For example: {0: "4GB", 1: "10GB"}

TYPE: Dict[Union[int, str], Union[int, str]] DEFAULT: None

offload_folder

The folder to offload the model to.

TYPE: str DEFAULT: None

download_kwargs

Keyword arguments used to configure how the model is downloaded.

TYPE: Dict DEFAULT: None

**config_kwargs

Additional kwargs that are passed to the config during initialization.

TYPE: Dict DEFAULT: {}

Source code in awq/models/base.py
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
@classmethod
def from_quantized(
    self,
    model_path: Annotated[str, Doc("A Huggingface path or local path to a model.")],
    model_type: Annotated[str, Doc("The model type, loaded from config.json.")],
    model_filename: Annotated[
        str, Doc("Load a specific model's filename by specifying this argument.")
    ] = "",
    max_seq_len: Annotated[
        int,
        Doc(
            "The maximum sequence cached sequence length of the model. Larger values may increase loading time and memory usage."
        ),
    ] = None,
    torch_dtype: Annotated[
        torch.dtype,
        Doc(
            "The dtype to load the model as. May not work with other values than float16."
        ),
    ] = torch.float16,
    trust_remote_code: Annotated[
        bool,
        Doc(
            "Useful for Huggingface repositories that have not been integrated into transformers yet."
        ),
    ] = True,
    safetensors: Annotated[
        bool, Doc("Whether to download/load safetensors instead of torch weights.")
    ] = True,
    fuse_layers: Annotated[
        bool,
        Doc(
            "Whether to use fused/optimized combination of layers for increased speed."
        ),
    ] = True,
    use_exllama: Annotated[
        bool, Doc("Whether to map the weights to ExLlamaV1 kernels.")
    ] = False,
    use_exllama_v2: Annotated[
        bool, Doc("Whether to map the weights to ExLlamaV2 kernels.")
    ] = False,
    use_ipex: Annotated[
        bool, Doc("Whether to map the weights to ipex kernels for CPU device.")
    ] = False,
    device_map: Annotated[
        Union[str, Dict],
        Doc(
            "A device map that will be passed onto the model loading method from transformers."
        ),
    ] = "balanced",
    max_memory: Annotated[
        Dict[Union[int, str], Union[int, str]],
        Doc(
            'A dictionary device identifier to maximum memory which will be passed onto the model loading method from transformers. For example:{0: "4GB",1: "10GB"'
        ),
    ] = None,
    offload_folder: Annotated[
        str,
        Doc("The folder ot offload the model to."),
    ] = None,
    download_kwargs: Annotated[
        Dict,
        Doc("Used for configure download model"),
    ] = None,
    **config_kwargs: Annotated[
        Dict,
        Doc(
            "Additional kwargs that are passed to the config during initialization."
        ),
    ],
):
    """A method for initialization of a quantized model, usually in INT4.

    ``self`` is the class here (this is a ``classmethod``). The returned object
    is ``self(...)`` wrapping the loaded model, with ``is_quantized=True`` and
    ``processor=None``.

    The loading steps run in this order, and the order matters:

    1. ``_load_config`` resolves the weights path, the transformers config and
       the quantization config. ``max_seq_len`` and ``**config_kwargs`` are
       forwarded to it.
    2. The model skeleton is built from the config inside
       ``init_empty_weights()``.
    3. ``nn.Linear`` layers are replaced by quantized WQLinear modules via
       ``_load_quantized_modules``.
    4. Weights are tied, then loaded and dispatched to devices with
       ``load_checkpoint_and_dispatch``.
    5. Layers are optionally fused.
    6. A backend-specific post-init runs: IPEX, Marlin, ExLlama v1 or
       ExLlama v2.

    On the IPEX path the model is cast to ``torch.bfloat16`` on the CPU,
    regardless of ``torch_dtype``.

    Raises:
        ImportError: if the IPEX path is selected (``use_ipex=True``, or the
            best available device is the CPU) but
            ``intel_extension_for_pytorch`` is not available.
    """
    # [STEP 1-2] Load weights path and configs
    model_weights_path, config, quant_config = self._load_config(
        self,
        model_path,
        model_filename,
        safetensors,
        trust_remote_code,
        max_seq_len=max_seq_len,
        download_kwargs=download_kwargs,
        **config_kwargs,
    )

    # Transformers auto-class registered for this architecture.
    target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
    target_cls = getattr(transformers, target_cls_name)

    # [STEP 3] Load model
    # Build the module structure only; the real weights arrive in
    # load_checkpoint_and_dispatch below.
    with init_empty_weights():
        model = target_cls.from_config(
            config=config,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )

    # The IPEX kernel path is used when explicitly requested, or automatically
    # when the best available device is the CPU.
    use_cpu_ipex = use_ipex or get_best_device() == "cpu"
    if use_cpu_ipex and not ipex_available:
        raise ImportError(
            "Please install intel_extension_for_pytorch with "
            "`pip install intel_extension_for_pytorch` for 'ipex' kernel!"
        )
    # Prepare WQLinear layers, replace nn.Linear
    self._load_quantized_modules(
        self,
        model,
        quant_config,
        quant_config.version,
        use_exllama=use_exllama,
        use_exllama_v2=use_exllama_v2,
        use_ipex=use_cpu_ipex,
    )

    # Tie any weights the transformers model declares as shared, before the
    # checkpoint is loaded.
    model.tie_weights()

    # loads the weights into modules and distributes
    # across available devices automatically
    # NOTE(review): self.layer_type is presumably the decoder-layer class name,
    # listed so each such layer is kept whole on one device — confirm.
    load_checkpoint_and_dispatch(
        model,
        checkpoint=model_weights_path,
        device_map=device_map,
        max_memory=max_memory,
        no_split_module_classes=[self.layer_type],
        offload_folder=offload_folder,
        dtype=torch_dtype,
    )

    # Optionally fuse layers. Fusing needs the compiled `awq_ext` extension;
    # if it cannot be imported, fusing is skipped with a warning.
    awq_ext, msg = try_import("awq_ext")
    if fuse_layers:
        if awq_ext is None:
            warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg)
        else:
            self.fuse_layers(model)

    # Backend-specific post-initialization; at most one branch runs, and the
    # IPEX path takes priority over Marlin and ExLlama.
    if use_cpu_ipex:
        # IPEX runs in bfloat16 on CPU, overriding `torch_dtype`.
        dtype = torch.bfloat16
        model.to(dtype=dtype, device="cpu")
        # repack qweight to match the ipex kernel.
        model = ipex_post_init(model)
    elif quant_config.version == "marlin":
        model = marlin_post_init(model)
    elif use_exllama:
        # creates q4 handle
        model = exllama_post_init(model)
    elif use_exllama_v2:
        # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
        # max_seq_len falls back to 2048; the batch size comes from the
        # AWQ_BATCH_SIZE environment variable (default 1).
        model = exllamav2_post_init(
            model,
            max_input_len=max_seq_len or 2048,
            max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
        )

    model.eval()

    return self(
        model,
        model_type,
        is_quantized=True,
        config=config,
        quant_config=quant_config,
        processor=None,
    )