Support intern-s1 #14875
base: master
Changes from 2 commits
convert_hf_to_gguf.py
```diff
@@ -607,13 +607,14 @@
         toktypes: list[int] = []

         from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
-        assert max(tokenizer.vocab.values()) < vocab_size
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+        vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
+        vocab_size = self.hparams.get("vocab_size", len(vocab))
+        assert max(vocab.values()) < vocab_size

         tokpre = self.get_vocab_base_pre(tokenizer)

-        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
         added_vocab = tokenizer.get_added_vocab()

         added_tokens_decoder = tokenizer.added_tokens_decoder
```
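For context on the change above: some remote-code tokenizers, Intern-S1's among them, do not expose a `.vocab` attribute, while `get_vocab()` is part of the common `transformers` tokenizer interface, so the `getattr` fallback keeps the conversion working for both kinds. A minimal sketch of the same pattern in isolation (the helper name is mine):

```python
from transformers import AutoTokenizer

def load_vocab(model_dir: str) -> dict[str, int]:
    # trust_remote_code=True lets transformers import the custom tokenizer
    # code shipped with checkpoints such as Intern-S1.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    # Fall back to get_vocab() when the tokenizer has no .vocab attribute.
    return getattr(tokenizer, "vocab", None) or tokenizer.get_vocab()
```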
```diff
@@ -1218,8 +1219,12 @@
         self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)

         # load preprocessor config
-        with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
-            self.preprocessor_config = json.load(f)
+        preprocess_config_file = self.dir_model / "preprocessor_config.json"
+        if preprocess_config_file.exists():
+            with open(preprocess_config_file, "r", encoding="utf-8") as f:
+                self.preprocessor_config = json.load(f)
+        else:
+            self.preprocessor_config = dict(image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225])

     def get_vision_config(self) -> dict[str, Any] | None:
         return self.global_config.get("vision_config")
```
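The hard-coded fallback values are the standard ImageNet per-channel normalization statistics, which InternViT-style vision towers conventionally assume. Roughly, they are applied like this during image preprocessing (a sketch; the function is illustrative, not code from this PR):

```python
import numpy as np

# Standard ImageNet statistics, one value per RGB channel.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(image: np.ndarray) -> np.ndarray:
    """Normalize an HWC float image with values in [0, 1]."""
    return (image - IMAGENET_MEAN) / IMAGENET_STD
```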
```diff
@@ -2998,7 +3003,12 @@
 @ModelBase.register("InternVisionModel")
 class InternVisionModel(MmprojModel):
     def set_gguf_parameters(self):
+        if isinstance(self.hparams_vision['image_size'], list):
+            self.hparams_vision['image_size'] = self.hparams_vision['image_size'][0]
+        if isinstance(self.hparams_vision['patch_size'], list):
+            self.hparams_vision['patch_size'] = self.hparams_vision['patch_size'][0]
         super().set_gguf_parameters()

         hparams = self.hparams
         self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
         self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
```
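The new guard handles checkpoints whose vision config stores `image_size` and `patch_size` as `[height, width]` lists rather than single integers; taking element 0 assumes square inputs. In isolation (the config values here are made up):

```python
# Hypothetical Intern-S1-style vision hparams with list-valued sizes.
hparams_vision = {"image_size": [448, 448], "patch_size": [14, 14]}

for key in ("image_size", "patch_size"):
    if isinstance(hparams_vision[key], list):
        # Assume a square image/patch and keep only the first entry.
        hparams_vision[key] = hparams_vision[key][0]

assert hparams_vision == {"image_size": 448, "patch_size": 14}
```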
```diff
@@ -3022,8 +3032,43 @@
             return gguf.GGMLQuantizationType.F32
         return False

+    def _mapping_name_interns1(self, name):
+        names_map = {
+            "model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
+            "model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
+            "model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
+            "model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
+            "model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
+            "model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
+            "model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding",
+            "model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias",
+            "model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight",
+            "model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding",
+        }
+        if name in names_map:
+            name = names_map[name]
+        elif name.startswith("model.language_model."):
+            name = "language_model.model." + name[len("model.language_model.") :]
+        elif name.startswith("model.vision_tower."):
+            name = "vision_model." + name[len("model.vision_tower.") :]
+
+        if name.startswith("vision_model.encoder.layer"):
+            name = name.replace(r".layer.", r".layers.")
+            name = name.replace(r".attention.", r".attn.")
+            name = name.replace(r".attn.q_proj", r".self_attn.q_proj")
+            name = name.replace(r".attn.k_proj", r".self_attn.k_proj")
+            name = name.replace(r".attn.v_proj", r".self_attn.v_proj")
+            name = name.replace(r".projection_layer.", r".proj.")
+            name = name.replace(r".lambda_1", r".ls1")
+            name = name.replace(r".lambda_2", r".ls2")
+            name = name.replace(r".layernorm_before.", r".norm1.")
+            name = name.replace(r".layernorm_after.", r".norm2.")
+        return name
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
+        name = self._mapping_name_interns1(name)
+        # support interns1
         if name.startswith("vision_model") or name.startswith("mlp"):
             # process visual tensors
             # correct name
```
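To make the remapping concrete, here are a few before/after pairs it produces (the input tensor names are illustrative of the HF Intern-S1 layout, not dumped from the checkpoint):

```python
examples = {
    # explicit hit in names_map
    "model.multi_modal_projector.linear_1.weight":
        "mlp1.1.weight",
    # language tower: plain prefix swap
    "model.language_model.layers.0.mlp.gate_proj.weight":
        "language_model.model.layers.0.mlp.gate_proj.weight",
    # vision tower: prefix swap followed by the encoder rename chain
    "model.vision_tower.encoder.layer.0.attention.q_proj.weight":
        "vision_model.encoder.layers.0.self_attn.q_proj.weight",
}
```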
```diff
@@ -3115,13 +3160,17 @@
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         name = name.replace(r"language_model.", r"")  # InternVL
+        if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
             # skip visual tensors
             return []
         if name.find("experts") != -1:
             n_experts = self.hparams["num_experts"]
             assert bid is not None

             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]

             self._experts[bid][name] = data_torch

             if len(self._experts[bid]) >= n_experts * 3:
```
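The `n_experts * 3` threshold reflects that each expert contributes three tensors (gate, up, and down projections); only once all of them have been collected for a block can they be stacked into one 3D tensor per projection. A schematic of that merge step (tensor-name patterns are illustrative of a Qwen-MoE-style layout, not copied from the file):

```python
import torch

def merge_experts(collected: dict[str, torch.Tensor], bid: int, n_experts: int):
    # Stack the per-expert 2D weights into one 3D tensor per projection.
    for proj in ("gate_proj", "up_proj", "down_proj"):
        datas = [
            collected[f"model.layers.{bid}.mlp.experts.{xid}.{proj}.weight"]
            for xid in range(n_experts)
        ]
        yield proj, torch.stack(datas, dim=0)  # shape: (n_experts, rows, cols)
```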
```diff
@@ -3168,6 +3217,41 @@
 class Qwen3MoeModel(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.QWEN3MOE

+    def set_vocab(self):
+        # deal with interns1
+        if 'interns1' in f'{self.dir_model}'.lower():
+            self._set_vocab_interns1()
+            return
+
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def _set_vocab_interns1(self):
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
+        additional_special_tokens = []
+        if special_tokens_map_file.is_file():
+            with open(special_tokens_map_file, encoding='utf-8') as f:
+                additional_special_tokens = json.load(f).get('additional_special_tokens', [])
+        tokenizer_cfg_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_cfg_file.is_file():
+            with open(tokenizer_cfg_file, encoding='utf-8') as f:
+                added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
+            token2ids_map = {data['content']: int(token) for token, data in added_tokens_decoder.items() if data['special']}
+            for token in additional_special_tokens:
+                if token in token2ids_map:
+                    special_vocab._set_special_token(token, token2ids_map[token])
+        special_vocab._set_special_token('eos', 151645)
+        special_vocab._set_special_token('bos', 151643)
+        special_vocab.add_to_gguf(self.gguf_writer)

 @ModelBase.register("GPT2LMHeadModel")
 class GPT2Model(TextModel):
```

Review thread on `_set_vocab_interns1` (attached to the `get_vocab_base()` call):

**CISC:** This does not work because … The …

**PR author:** @CISC Hi, thanks for the reminder. The Intern-S1 tokenizer is indeed special: it is based on the Qwen3 BPE tokenizer and extended with three SPM tokenizer models, using regex patterns to decide which sub-vocabulary to tokenize with. I don't know how to implement this in llama.cpp. Do you have any suggestions for this special case? Thanks.

**CISC:** No easy feat I'm afraid, …
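For readers following the thread: the hard-coded IDs in `_set_vocab_interns1` correspond to Qwen's `<|im_end|>` (151645) and `<|endoftext|>` (151643) tokens. The composite tokenizer the author describes, a Qwen3 BPE base extended with three SentencePiece models selected by regex, might look roughly like the sketch below. Every name and pattern here is invented for illustration; this is not the actual Intern-S1 implementation:

```python
import re

class CompositeTokenizer:
    """Hypothetical regex-dispatched tokenizer in the style described above."""

    def __init__(self, bpe, spm_models: dict):
        self.bpe = bpe         # Qwen3-style BPE tokenizer, the default path
        self.spm = spm_models  # e.g. {"dna": spm_a, "protein": spm_b, "smiles": spm_c}
        # Each pattern routes matching spans to one SentencePiece model.
        self.routes = [
            (re.compile(r"<dna>.*?</dna>", re.S), "dna"),
            (re.compile(r"<protein>.*?</protein>", re.S), "protein"),
        ]

    def encode(self, text: str) -> list:
        for pattern, key in self.routes:
            m = pattern.search(text)
            if m:
                # Tokenize the matched span with its SPM model and recurse
                # on the surrounding text.
                return (self.encode(text[:m.start()])
                        + self.spm[key].encode(m.group(0))
                        + self.encode(text[m.end():]))
        return self.bpe.encode(text)  # no route matched: plain BPE
```

Expressing this kind of routing in llama.cpp's single-vocabulary tokenizer model is the difficulty the reviewers allude to.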