cksghl1004 committed on
Commit e98a1bd · verified
1 Parent(s): b01b7f9

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ moondream2-mmproj-f16.gguf filter=lfs diff=lfs merge=lfs -text
37
+ moondream2-text-model-f16.gguf filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,74 @@
1
+ ---
2
+ license: apache-2.0
3
+ pipeline_tag: image-text-to-text
4
+ ---
5
+
6
+ Moondream is a small vision language model designed to run efficiently everywhere.
7
+
8
+ [Website](https://moondream.ai/) / [Demo](https://moondream.ai/playground) / [GitHub](https://github.com/vikhyat/moondream)
9
+
10
+ This repository contains the latest (**2025-06-21**) release of Moondream, as well as [historical releases](https://huggingface.co/vikhyatk/moondream2/blob/main/versions.txt). The model is updated frequently, so we recommend specifying a revision as shown below if you're using it in a production application.
11
+
12
+
13
+ ### Usage
14
+
15
+ ```python
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+ from PIL import Image
18
+
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ "vikhyatk/moondream2",
21
+ revision="2025-06-21",
22
+ trust_remote_code=True,
23
+ device_map={"": "cuda"} # ...or 'mps', on Apple Silicon
24
+ )
25
+
+ # Load the image to run inference on (placeholder path)
+ image = Image.open("<IMAGE_PATH>")
+
26
+ # Captioning
27
+ print("Short caption:")
28
+ print(model.caption(image, length="short")["caption"])
29
+
30
+ print("\nNormal caption:")
31
+ for t in model.caption(image, length="normal", stream=True)["caption"]:
32
+ # Streaming generation example, supported for caption() and query()
33
+ print(t, end="", flush=True)
34
+ print(model.caption(image, length="normal"))
35
+
36
+ # Visual Querying
37
+ print("\nVisual query: 'How many people are in the image?'")
38
+ print(model.query(image, "How many people are in the image?")["answer"])
39
+
40
+ # Object Detection
41
+ print("\nObject detection: 'face'")
42
+ objects = model.detect(image, "face")["objects"]
43
+ print(f"Found {len(objects)} face(s)")
44
+
45
+ # Pointing
46
+ print("\nPointing: 'person'")
47
+ points = model.point(image, "person")["points"]
48
+ print(f"Found {len(points)} person(s)")
49
+ ```
50
+
51
+ ### Changelog
52
+
53
+ **2025-06-21**
54
+
55
+ (release notes coming soon)
56
+
57
+ **2025-04-15** ([full release notes](https://moondream.ai/blog/moondream-2025-04-14-release))
58
+
59
+ 1. Improved chart understanding (ChartQA up from 74.8 to 77.5, 82.2 with PoT)
60
+ 2. Added temperature and nucleus sampling to reduce repetitive outputs
61
+ 3. Better OCR for documents and tables (prompt with “Transcribe the text” or “Transcribe the text in natural reading order”)
62
+ 4. Object detection supports document layout detection (figure, formula, text, etc.)
63
+ 5. UI understanding (ScreenSpot F1\@0.5 up from 53.3 to 60.3)
64
+ 6. Improved text understanding (DocVQA up from 76.5 to 79.3, TextVQA up from 74.6 to 76.3)
65
+
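Item 3 in the list above suggests prompting with "Transcribe the text" for document OCR. A minimal, hedged sketch using the `query` API from the Usage section (it assumes `model` and a PIL `image` are already loaded as shown there):

```python
# Illustrative only: OCR-style prompt via the standard query API.
result = model.query(image, "Transcribe the text in natural reading order")
print(result["answer"])
```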
66
+ **2025-03-27** ([full release notes](https://moondream.ai/blog/moondream-2025-03-27-release))
67
+
68
+ 1. Added support for long-form captioning
69
+ 2. Open vocabulary image tagging
70
+ 3. Improved counting accuracy (e.g. CountBenchQA increased from 80 to 86.4)
71
+ 4. Improved text understanding (e.g. OCRBench increased from 58.3 to 61.2)
72
+ 5. Improved object detection, especially for small objects (e.g. COCO up from 30.5 to 51.2)
73
+ 6. Fixed token streaming bug affecting multi-byte unicode characters
74
+ 7. gpt-fast style `compile()` now supported in HF Transformers implementation
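Item 7 above mentions gpt-fast style `compile()` support in the Transformers implementation. Below is a minimal sketch of how that might be invoked, assuming the HF wrapper exposes the underlying `MoondreamModel` as `model.model` (as `hf_moondream.py` in this commit does); treat it as illustrative rather than official usage.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2",
    revision="2025-06-21",
    trust_remote_code=True,
    device_map={"": "cuda"},
)
# Compiles the vision encoder, prefill, and single-token decode paths with torch.compile.
model.model.compile()
```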
added_tokens.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "\t\t": 50294,
3
+ "\t\t\t": 50293,
4
+ "\t\t\t\t": 50292,
5
+ "\t\t\t\t\t": 50291,
6
+ "\t\t\t\t\t\t": 50290,
7
+ "\t\t\t\t\t\t\t": 50289,
8
+ "\t\t\t\t\t\t\t\t": 50288,
9
+ "\t\t\t\t\t\t\t\t\t": 50287,
10
+ " ": 50286,
11
+ " ": 50285,
12
+ " ": 50284,
13
+ " ": 50283,
14
+ " ": 50282,
15
+ " ": 50281,
16
+ " ": 50280,
17
+ " ": 50279,
18
+ " ": 50278,
19
+ " ": 50277,
20
+ " ": 50276,
21
+ " ": 50275,
22
+ " ": 50274,
23
+ " ": 50273,
24
+ " ": 50272,
25
+ " ": 50271,
26
+ " ": 50270,
27
+ " ": 50269,
28
+ " ": 50268,
29
+ " ": 50267,
30
+ " ": 50266,
31
+ " ": 50265,
32
+ " ": 50264,
33
+ " ": 50263,
34
+ " ": 50262,
35
+ " ": 50261,
36
+ " ": 50260,
37
+ " ": 50259,
38
+ " ": 50258,
39
+ " ": 50257
40
+ }
config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "architectures": [
3
+ "HfMoondream"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "hf_moondream.HfConfig",
7
+ "AutoModelForCausalLM": "hf_moondream.HfMoondream"
8
+ },
9
+ "config": {},
10
+ "model_type": "moondream1",
11
+ "torch_dtype": "bfloat16",
12
+ "transformers_version": "4.52.4"
13
+ }
config.py ADDED
@@ -0,0 +1,94 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class TextConfig:
7
+ dim: int = 2048
8
+ ff_dim: int = 8192
9
+ n_layers: int = 24
10
+ vocab_size: int = 51200
11
+ max_context: int = 2048
12
+ n_heads: int = 32
13
+ n_kv_heads: int = 32
14
+ prefix_attn: int = 730
15
+ group_size: Optional[int] = None
16
+
17
+
18
+ @dataclass(frozen=True)
19
+ class VisionConfig:
20
+ enc_dim: int = 1152
21
+ enc_patch_size: int = 14
22
+ enc_n_layers: int = 27
23
+ enc_ff_dim: int = 4304
24
+ enc_n_heads: int = 16
25
+ proj_out_dim: int = 2048
26
+ crop_size: int = 378
27
+ in_channels: int = 3
28
+ max_crops: int = 12
29
+ overlap_margin: int = 4
30
+ proj_inner_dim: int = 8192
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class RegionConfig:
35
+ dim: int = 2048
36
+ coord_feat_dim: int = 256
37
+ coord_out_dim: int = 1024
38
+ size_feat_dim: int = 512
39
+ size_out_dim: int = 2048
40
+ inner_dim: int = 8192
41
+ group_size: Optional[int] = None
42
+
43
+
44
+ @dataclass(frozen=True)
45
+ class TokenizerConfig:
46
+ bos_id: int = 0
47
+ eos_id: int = 0
48
+ answer_id: int = 3
49
+ thinking_id: int = 4
50
+ coord_id: int = 5
51
+ size_id: int = 6
52
+ start_ground_points_id: int = 7
53
+ end_ground_id: int = 9
54
+ templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
55
+ default_factory=lambda: {
56
+ "caption": {
57
+ "short": [1, 32708, 2, 12492, 3],
58
+ "normal": [1, 32708, 2, 6382, 3],
59
+ "long": [1, 32708, 2, 4059, 3],
60
+ },
61
+ "query": {"prefix": [1, 15381, 2], "suffix": [3]},
62
+ "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
63
+ "point": {"prefix": [1, 2581, 2], "suffix": [3]},
64
+ }
65
+ )
66
+
67
+
68
+ @dataclass(frozen=True)
69
+ class MoondreamConfig:
70
+ text: TextConfig = TextConfig()
71
+ vision: VisionConfig = VisionConfig()
72
+ region: RegionConfig = RegionConfig()
73
+ tokenizer: TokenizerConfig = TokenizerConfig()
74
+
75
+ @classmethod
76
+ def from_dict(cls, config_dict: dict):
77
+ text_config = TextConfig(**config_dict.get("text", {}))
78
+ vision_config = VisionConfig(**config_dict.get("vision", {}))
79
+ region_config = RegionConfig(**config_dict.get("region", {}))
80
+ tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
81
+ return cls(
82
+ text=text_config,
83
+ vision=vision_config,
84
+ region=region_config,
85
+ tokenizer=tokenizer_config,
86
+ )
87
+
88
+ def to_dict(self):
89
+ return {
90
+ "text": self.text.__dict__,
91
+ "vision": self.vision.__dict__,
92
+ "region": self.region.__dict__,
93
+ "tokenizer": self.tokenizer.__dict__,
94
+ }
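A minimal sketch of how these frozen config dataclasses are meant to be used, assuming this file is importable as `config` (the names below mirror the definitions above; only `from_dict`/`to_dict` are part of the file itself):

```python
from config import MoondreamConfig

cfg = MoondreamConfig.from_dict({"text": {"n_layers": 24}, "vision": {"max_crops": 8}})
print(cfg.text.n_layers, cfg.vision.max_crops)  # 24 8
print(cfg.vision.crop_size)                     # 378 (default)

# Round trip: to_dict() emits the same nested structure from_dict() consumes,
# and frozen dataclasses compare by value.
assert MoondreamConfig.from_dict(cfg.to_dict()) == cfg
```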
configuration_moondream.py ADDED
@@ -0,0 +1,96 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class PhiConfig(PretrainedConfig):
5
+ model_type = "phi"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=51200,
11
+ hidden_size=2048,
12
+ intermediate_size=8192,
13
+ num_hidden_layers=24,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=None,
16
+ resid_pdrop=0.0,
17
+ embd_pdrop=0.0,
18
+ attention_dropout=0.0,
19
+ hidden_act="gelu_new",
20
+ max_position_embeddings=2048,
21
+ initializer_range=0.02,
22
+ layer_norm_eps=1e-5,
23
+ use_cache=True,
24
+ tie_word_embeddings=False,
25
+ rope_theta=10000.0,
26
+ rope_scaling=None,
27
+ partial_rotary_factor=0.5,
28
+ bos_token_id=1,
29
+ eos_token_id=2,
30
+ **kwargs,
31
+ ):
32
+ self.vocab_size = vocab_size
33
+ self.hidden_size = hidden_size
34
+ self.intermediate_size = intermediate_size
35
+ self.num_hidden_layers = num_hidden_layers
36
+ self.num_attention_heads = num_attention_heads
37
+
38
+ if num_key_value_heads is None:
39
+ num_key_value_heads = num_attention_heads
40
+
41
+ self.num_key_value_heads = num_key_value_heads
42
+ self.resid_pdrop = resid_pdrop
43
+ self.embd_pdrop = embd_pdrop
44
+ self.attention_dropout = attention_dropout
45
+ self.hidden_act = hidden_act
46
+ self.max_position_embeddings = max_position_embeddings
47
+ self.initializer_range = initializer_range
48
+ self.layer_norm_eps = layer_norm_eps
49
+ self.use_cache = use_cache
50
+ self.rope_theta = rope_theta
51
+ self.rope_scaling = rope_scaling
52
+ self.partial_rotary_factor = partial_rotary_factor
53
+ self._rope_scaling_validation()
54
+
55
+ super().__init__(
56
+ bos_token_id=bos_token_id,
57
+ eos_token_id=eos_token_id,
58
+ tie_word_embeddings=tie_word_embeddings,
59
+ **kwargs,
60
+ )
61
+
62
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
63
+ def _rope_scaling_validation(self):
64
+ """
65
+ Validate the `rope_scaling` configuration.
66
+ """
67
+ if self.rope_scaling is None:
68
+ return
69
+
70
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
71
+ raise ValueError(
72
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
73
+ f"got {self.rope_scaling}"
74
+ )
75
+ rope_scaling_type = self.rope_scaling.get("type", None)
76
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
77
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
78
+ raise ValueError(
79
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
80
+ )
81
+ if (
82
+ rope_scaling_factor is None
83
+ or not isinstance(rope_scaling_factor, float)
84
+ or rope_scaling_factor <= 1.0
85
+ ):
86
+ raise ValueError(
87
+ f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
88
+ )
89
+
90
+
91
+ class MoondreamConfig(PretrainedConfig):
92
+ model_type = "moondream1"
93
+
94
+ def __init__(self, **kwargs):
95
+ self.text_config = PhiConfig(**kwargs.pop("text_config", {}))
96
+ super().__init__(**kwargs)
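A small sketch of the `rope_scaling` validation above (assuming this file is importable as `configuration_moondream`): a well-formed dict passes, while an out-of-range factor raises.

```python
from configuration_moondream import PhiConfig

PhiConfig(rope_scaling={"type": "linear", "factor": 2.0})  # valid: type in {linear, dynamic}, factor > 1

try:
    PhiConfig(rope_scaling={"type": "linear", "factor": 0.5})
except ValueError as e:
    print(e)  # factor must be a float > 1
```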
fourier_features.py ADDED
@@ -0,0 +1,18 @@
1
+ # Adapted from https://github.com/crowsonkb/k-diffusion/blob/transformer-model-v2/k_diffusion/layers.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+
8
+ class FourierFeatures(nn.Module):
9
+ def __init__(self, in_features, out_features, std=1.0):
10
+ super().__init__()
11
+ assert out_features % 2 == 0
12
+ self.register_buffer(
13
+ "weight", torch.randn([out_features // 2, in_features]) * std
14
+ )
15
+
16
+ def forward(self, input):
17
+ f = 2 * math.pi * input @ self.weight.T
18
+ return torch.cat([f.cos(), f.sin()], dim=-1)
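A shape-only sketch of `FourierFeatures` (import path assumed): each input coordinate vector is projected onto a fixed random basis and expanded into concatenated cosine and sine features.

```python
import torch
from fourier_features import FourierFeatures  # assumed import path

ff = FourierFeatures(in_features=2, out_features=256, std=1.0)
coords = torch.rand(10, 2)   # e.g. normalized (x, y) points
feats = ff(coords)           # [cos | sin] halves of the random projection
print(feats.shape)           # torch.Size([10, 256])
```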
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.44.0"
4
+ }
handler.py ADDED
@@ -0,0 +1,58 @@
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from PIL import Image
3
+ import torch
4
+ from io import BytesIO
5
+ import base64
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, model_dir):
9
+ self.model_id = "vikhyatk/moondream2"
10
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True)
11
+ self.tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2", trust_remote_code=True)
12
+
13
+ # Check if CUDA (GPU support) is available and then set the device to GPU or CPU
14
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ self.model.to(self.device)
16
+
17
+ def preprocess_image(self, encoded_image):
18
+ """Decode and preprocess the input image."""
19
+ decoded_image = base64.b64decode(encoded_image)
20
+ img = Image.open(BytesIO(decoded_image)).convert("RGB")
21
+ return img
22
+
23
+ def __call__(self, data):
24
+ """Handle the incoming request."""
25
+ try:
26
+ # Extract the inputs from the data
27
+ inputs = data.pop("inputs", data)
28
+ input_image = inputs['image']
29
+ question = inputs.get('question', "move to the red ball")
30
+
31
+ # Preprocess the image
32
+ img = self.preprocess_image(input_image)
33
+
34
+ # Perform inference
35
+ enc_image = self.model.encode_image(img).to(self.device)
36
+ answer = self.model.answer_question(enc_image, question, self.tokenizer)
37
+
38
+ # If the output is a tensor, move it back to CPU and convert to list
39
+ if isinstance(answer, torch.Tensor):
40
+ answer = answer.cpu().numpy().tolist()
41
+
42
+ # Create the response
43
+ response = {
44
+ "statusCode": 200,
45
+ "body": {
46
+ "answer": answer
47
+ }
48
+ }
49
+ return response
50
+ except Exception as e:
51
+ # Handle any errors
52
+ response = {
53
+ "statusCode": 500,
54
+ "body": {
55
+ "error": str(e)
56
+ }
57
+ }
58
+ return response
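A hedged sketch of exercising `EndpointHandler` locally with the payload shape it expects (`inputs.image` as base64, optional `inputs.question`); the image path below is a placeholder.

```python
import base64
from handler import EndpointHandler  # assumed import path

handler = EndpointHandler(model_dir=".")

with open("<IMAGE_PATH>", "rb") as f:  # placeholder path
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = {"inputs": {"image": encoded, "question": "What is in this image?"}}
print(handler(payload)["body"])
```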
hf_moondream.py ADDED
@@ -0,0 +1,183 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import PreTrainedModel, PretrainedConfig
5
+ from typing import Union
6
+
7
+ from .config import MoondreamConfig
8
+ from .moondream import MoondreamModel
9
+
10
+ # Files sometimes don't get loaded without these...
11
+ from .image_crops import *
12
+ from .vision import *
13
+ from .text import *
14
+ from .region import *
15
+ from .utils import *
16
+
17
+
18
+ def extract_question(text):
19
+ prefix = "<image>\n\nQuestion: "
20
+ suffix = "\n\nAnswer:"
21
+
22
+ if text.startswith(prefix) and text.endswith(suffix):
23
+ return text[len(prefix) : -len(suffix)]
24
+ else:
25
+ return None
26
+
27
+
28
+ class HfConfig(PretrainedConfig):
29
+ _auto_class = "AutoConfig"
30
+ model_type = "moondream1"
31
+
32
+ def __init__(self, **kwargs):
33
+ super().__init__(**kwargs)
34
+ self.config = {}
35
+
36
+
37
+ class HfMoondream(PreTrainedModel):
38
+ _auto_class = "AutoModelForCausalLM"
39
+ config_class = HfConfig
40
+
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.model = MoondreamModel(
44
+ MoondreamConfig.from_dict(config.config), setup_caches=False
45
+ )
46
+ self._is_kv_cache_setup = False
47
+
48
+ def _setup_caches(self):
49
+ if not self._is_kv_cache_setup:
50
+ self.model._setup_caches()
51
+ self._is_kv_cache_setup = True
52
+
53
+ @property
54
+ def encode_image(self):
55
+ self._setup_caches()
56
+ return self.model.encode_image
57
+
58
+ @property
59
+ def query(self):
60
+ self._setup_caches()
61
+ return self.model.query
62
+
63
+ @property
64
+ def caption(self):
65
+ self._setup_caches()
66
+ return self.model.caption
67
+
68
+ @property
69
+ def detect(self):
70
+ self._setup_caches()
71
+ return self.model.detect
72
+
73
+ @property
74
+ def point(self):
75
+ self._setup_caches()
76
+ return self.model.point
77
+
78
+ @property
79
+ def detect_gaze(self):
80
+ self._setup_caches()
81
+ return self.model.detect_gaze
82
+
83
+ def answer_question(
84
+ self,
85
+ image_embeds,
86
+ question,
87
+ tokenizer=None,
88
+ chat_history="",
89
+ result_queue=None,
90
+ max_new_tokens=256,
91
+ **kwargs
92
+ ):
93
+ answer = self.query(image_embeds, question)["answer"].strip()
94
+
95
+ if result_queue is not None:
96
+ result_queue.put(answer)
97
+ return answer
98
+
99
+ def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
100
+ answers = []
101
+ for image, prompt in zip(images, prompts):
102
+ answers.append(self.query(image, prompt)["answer"].strip())
103
+ return answers
104
+
105
+ def _unsupported_exception(self):
106
+ raise NotImplementedError(
107
+ "This method is not supported in the latest version of moondream. "
108
+ "Consider upgrading to the updated API spec, or alternately pin "
109
+ "to 'revision=2024-08-26'."
110
+ )
111
+
112
+ def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
113
+ """
114
+ Function definition remains unchanged for backwards compatibility.
115
+ Be aware that tokenizer, max_new_tokens, and kwargs are ignored.
116
+ """
117
+ prompt_extracted = extract_question(prompt)
118
+ if prompt_extracted is not None:
119
+ answer = self.model.query(
120
+ image=image_embeds, question=prompt_extracted, stream=False
121
+ )["answer"]
122
+ else:
123
+ image_embeds = self.encode_image(image_embeds)
124
+ prompt_tokens = torch.tensor(
125
+ [self.model.tokenizer.encode(prompt).ids],
126
+ device=self.device,
127
+ )
128
+
129
+ def generator():
130
+ for token in self.model._generate_answer(
131
+ prompt_tokens,
132
+ image_embeds.kv_cache,
133
+ image_embeds.pos,
134
+ max_new_tokens,
135
+ ):
136
+ yield token
137
+
138
+ answer = "".join(list(generator()))
139
+
140
+ return [answer]
141
+
142
+ def get_input_embeddings(self) -> nn.Embedding:
143
+ """
144
+ Lazily wrap the raw parameter `self.model.text.wte` in a real
145
+ `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
146
+ **shares** the weight tensor—no copy is made.
147
+ """
148
+ if not hasattr(self, "_input_embeddings"):
149
+ self._input_embeddings = nn.Embedding.from_pretrained(
150
+ self.model.text.wte, # tensor created in text.py
151
+ freeze=True, # set to False if you need it trainable
152
+ )
153
+ return self._input_embeddings
154
+
155
+ def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
156
+ """
157
+ Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
158
+ embeddings and keeps everything tied to `self.model.text.wte`.
159
+ """
160
+ # 1. point the low-level parameter to the new weight matrix
161
+ self.model.text.wte = value.weight
162
+ # 2. keep a reference for get_input_embeddings()
163
+ self._input_embeddings = value
164
+
165
+ def input_embeds(
166
+ self,
167
+ input_ids: Union[torch.LongTensor, list, tuple],
168
+ *,
169
+ device: torch.device | None = None
170
+ ) -> torch.FloatTensor:
171
+ """
172
+ Back-compat wrapper that turns token IDs into embeddings.
173
+
174
+ Example:
175
+ ids = torch.tensor([[1, 2, 3]])
176
+ embeds = model.input_embeds(ids) # (1, 3, hidden_dim)
177
+ """
178
+ if not torch.is_tensor(input_ids):
179
+ input_ids = torch.as_tensor(input_ids)
180
+ if device is not None:
181
+ input_ids = input_ids.to(device)
182
+
183
+ return self.get_input_embeddings()(input_ids)
image_crops.py ADDED
@@ -0,0 +1,231 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+ from typing import TypedDict
6
+
7
+ try:
8
+ import pyvips
9
+
10
+ HAS_VIPS = True
11
+ except ImportError:
12
+ from PIL import Image
13
+
14
+ HAS_VIPS = False
15
+
16
+
17
+ def select_tiling(
18
+ height: int, width: int, crop_size: int, max_crops: int
19
+ ) -> tuple[int, int]:
20
+ """
21
+ Determine the optimal number of tiles to cover an image with overlapping crops.
22
+ """
23
+ if height <= crop_size or width <= crop_size:
24
+ return (1, 1)
25
+
26
+ # Minimum required tiles in each dimension
27
+ min_h = math.ceil(height / crop_size)
28
+ min_w = math.ceil(width / crop_size)
29
+
30
+ # If minimum required tiles exceed max_crops, return proportional distribution
31
+ if min_h * min_w > max_crops:
32
+ ratio = math.sqrt(max_crops / (min_h * min_w))
33
+ return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
34
+
35
+ # Perfect aspect-ratio tiles that satisfy max_crops
36
+ h_tiles = math.floor(math.sqrt(max_crops * height / width))
37
+ w_tiles = math.floor(math.sqrt(max_crops * width / height))
38
+
39
+ # Ensure we meet minimum tile requirements
40
+ h_tiles = max(h_tiles, min_h)
41
+ w_tiles = max(w_tiles, min_w)
42
+
43
+ # If we exceeded max_crops, scale down the larger dimension
44
+ if h_tiles * w_tiles > max_crops:
45
+ if w_tiles > h_tiles:
46
+ w_tiles = math.floor(max_crops / h_tiles)
47
+ else:
48
+ h_tiles = math.floor(max_crops / w_tiles)
49
+
50
+ return (max(1, h_tiles), max(1, w_tiles))
51
+
52
+
53
+ class OverlapCropOutput(TypedDict):
54
+ crops: np.ndarray
55
+ tiling: tuple[int, int]
56
+
57
+
58
+ def overlap_crop_image(
59
+ image: np.ndarray,
60
+ overlap_margin: int,
61
+ max_crops: int,
62
+ base_size: tuple[int, int] = (378, 378),
63
+ patch_size: int = 14,
64
+ ) -> OverlapCropOutput:
65
+ """
66
+ Process an image using an overlap-and-resize cropping strategy with margin handling.
67
+
68
+ This function takes an input image and creates multiple overlapping crops with
69
+ consistent margins. It produces:
70
+ 1. A single global crop resized to base_size
71
+ 2. Multiple overlapping local crops that maintain high resolution details
72
+ 3. A patch ordering matrix that tracks correspondence between crops
73
+
74
+ The overlap strategy ensures:
75
+ - Smooth transitions between adjacent crops
76
+ - No loss of information at crop boundaries
77
+ - Proper handling of features that cross crop boundaries
78
+ - Consistent patch indexing across the full image
79
+
80
+ Args:
81
+ image (np.ndarray): Input image as numpy array with shape (H,W,C)
82
+ base_size (tuple[int,int]): Target size for crops, default (378,378)
83
+ patch_size (int): Size of patches in pixels, default 14
84
+ overlap_margin (int): Margin size in patch units, default 4
85
+ max_crops (int): Maximum number of crops allowed, default 12
86
+
87
+ Returns:
88
+ OverlapCropOutput: Dictionary containing:
89
+ - crops: A numpy array containing the global crop of the full image (index 0)
90
+ followed by the overlapping cropped regions (indices 1+)
91
+ - tiling: Tuple of (height,width) tile counts
92
+ """
93
+ original_h, original_w = image.shape[:2]
94
+
95
+ # Convert margin from patch units to pixels
96
+ margin_pixels = patch_size * overlap_margin
97
+ total_margin_pixels = margin_pixels * 2 # Both sides
98
+
99
+ # Calculate crop parameters
100
+ crop_patches = base_size[0] // patch_size # patches per crop dimension
101
+ crop_window_patches = crop_patches - (2 * overlap_margin) # usable patches
102
+ crop_window_size = crop_window_patches * patch_size # usable size in pixels
103
+
104
+ # Determine tiling
105
+ tiling = select_tiling(
106
+ original_h - total_margin_pixels,
107
+ original_w - total_margin_pixels,
108
+ crop_window_size,
109
+ max_crops,
110
+ )
111
+
112
+ # Pre-allocate crops.
113
+ n_crops = tiling[0] * tiling[1] + 1 # 1 = global crop
114
+ crops = np.zeros(
115
+ (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
116
+ )
117
+
118
+ # Resize image to fit tiling
119
+ target_size = (
120
+ tiling[0] * crop_window_size + total_margin_pixels,
121
+ tiling[1] * crop_window_size + total_margin_pixels,
122
+ )
123
+
124
+ if HAS_VIPS:
125
+ # Convert to vips for resizing
126
+ vips_image = pyvips.Image.new_from_array(image)
127
+ scale_x = target_size[1] / image.shape[1]
128
+ scale_y = target_size[0] / image.shape[0]
129
+ resized = vips_image.resize(scale_x, vscale=scale_y)
130
+ image = resized.numpy()
131
+
132
+ # Create global crop
133
+ scale_x = base_size[1] / vips_image.width
134
+ scale_y = base_size[0] / vips_image.height
135
+ global_vips = vips_image.resize(scale_x, vscale=scale_y)
136
+ crops[0] = global_vips.numpy()
137
+ else:
138
+ # Fallback to PIL
139
+ pil_img = Image.fromarray(image)
140
+ resized = pil_img.resize(
141
+ (int(target_size[1]), int(target_size[0])),
142
+ resample=Image.Resampling.LANCZOS,
143
+ )
144
+ image = np.asarray(resized)
145
+
146
+ # Create global crop
147
+ global_pil = pil_img.resize(
148
+ (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
149
+ )
150
+ crops[0] = np.asarray(global_pil)
151
+
152
+ for i in range(tiling[0]):
153
+ for j in range(tiling[1]):
154
+ # Calculate crop coordinates
155
+ y0 = i * crop_window_size
156
+ x0 = j * crop_window_size
157
+
158
+ # Extract crop with padding if needed
159
+ y_end = min(y0 + base_size[0], image.shape[0])
160
+ x_end = min(x0 + base_size[1], image.shape[1])
161
+
162
+ crop_region = image[y0:y_end, x0:x_end]
163
+ crops[
164
+ 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
165
+ ] = crop_region
166
+
167
+ return {"crops": crops, "tiling": tiling}
168
+
169
+
170
+ def reconstruct_from_crops(
171
+ crops: torch.Tensor,
172
+ tiling: tuple[int, int],
173
+ overlap_margin: int,
174
+ patch_size: int = 14,
175
+ ) -> torch.Tensor:
176
+ """
177
+ Reconstruct the original image from overlapping crops into a single seamless image.
178
+
179
+ Takes a list of overlapping image crops along with their positional metadata and
180
+ reconstructs them into a single coherent image by carefully stitching together
181
+ non-overlapping regions. Handles both numpy arrays and PyTorch tensors.
182
+
183
+ Args:
184
+ crops: List of image crops as numpy arrays or PyTorch tensors with shape
185
+ (H,W,C)
186
+ tiling: Tuple of (height,width) indicating crop grid layout
187
+ patch_size: Size in pixels of each patch, default 14
188
+ overlap_margin: Number of overlapping patches on each edge, default 4
189
+
190
+ Returns:
191
+ Reconstructed image as numpy array or PyTorch tensor matching input type,
192
+ with shape (H,W,C) where H,W are the original image dimensions
193
+ """
194
+ tiling_h, tiling_w = tiling
195
+ crop_height, crop_width = crops[0].shape[:2]
196
+ margin_pixels = overlap_margin * patch_size
197
+
198
+ # Calculate output size (only adding margins once)
199
+ output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
200
+ output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
201
+
202
+ reconstructed = torch.zeros(
203
+ (output_h, output_w, crops[0].shape[2]),
204
+ device=crops[0].device,
205
+ dtype=crops[0].dtype,
206
+ )
207
+
208
+ for i, crop in enumerate(crops):
209
+ tile_y = i // tiling_w
210
+ tile_x = i % tiling_w
211
+
212
+ # For each tile, determine which part to keep
213
+ # Keep left margin only for first column
214
+ x_start = 0 if tile_x == 0 else margin_pixels
215
+ # Keep right margin only for last column
216
+ x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
217
+ # Keep top margin only for first row
218
+ y_start = 0 if tile_y == 0 else margin_pixels
219
+ # Keep bottom margin only for last row
220
+ y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
221
+
222
+ # Calculate where this piece belongs in the output
223
+ out_x = tile_x * (crop_width - 2 * margin_pixels)
224
+ out_y = tile_y * (crop_height - 2 * margin_pixels)
225
+
226
+ # Place the piece
227
+ reconstructed[
228
+ out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
229
+ ] = crop[y_start:y_end, x_start:x_end]
230
+
231
+ return reconstructed
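A minimal sketch of the crop/reconstruct round trip described in the docstrings above, assuming this file is importable as `image_crops`; the printed shapes follow from the defaults (378 px crops, 14 px patches, 4-patch margins).

```python
import numpy as np
import torch
from image_crops import overlap_crop_image, reconstruct_from_crops

image = np.zeros((900, 1200, 3), dtype=np.uint8)   # dummy H x W x C image
out = overlap_crop_image(image, overlap_margin=4, max_crops=12)
print(out["tiling"], out["crops"].shape)           # (2, 4) (9, 378, 378, 3)

crops = torch.from_numpy(out["crops"][1:])         # local crops only (index 0 is the global crop)
full = reconstruct_from_crops(crops, out["tiling"], overlap_margin=4)
print(full.shape)                                  # torch.Size([644, 1176, 3])
```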
layers.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Literal, Optional
7
+
8
+ try:
9
+ from torchao import quantize_
10
+ from torchao.quantization import int4_weight_only
11
+ except ImportError:
12
+
13
+ def quantize_(model, quant_mode):
14
+ raise ImportError(
15
+ "torchao is not installed. Please install it with `pip install torchao`."
16
+ )
17
+
18
+ def int4_weight_only(group_size):
19
+ raise ImportError(
20
+ "torchao is not installed. Please install it with `pip install torchao`."
21
+ )
22
+
23
+
24
+ def gelu_approx(x):
25
+ return F.gelu(x, approximate="tanh")
26
+
27
+
28
+ @dataclass
29
+ class LinearWeights:
30
+ weight: torch.Tensor
31
+ bias: torch.Tensor
32
+
33
+
34
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
35
+ return F.linear(x, w.weight, w.bias)
36
+
37
+
38
+ def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
39
+ _step = W_q.shape[0]
40
+ W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
41
+ W_r[:_step] = (W_q & 0b11110000) >> 4
42
+ W_r[_step:] = W_q & 0b00001111
43
+ W_r.sub_(zero).mul_(scale)
44
+ return W_r.reshape(orig_shape)
45
+
46
+
47
+ class QuantizedLinear(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features: int,
51
+ out_features: int,
52
+ dtype: torch.dtype,
53
+ ):
54
+ # TODO: Take group_size as an input instead of hardcoding it here.
55
+ super().__init__()
56
+ self.in_features = in_features
57
+ self.out_features = out_features
58
+ self.weight = nn.ParameterDict(
59
+ {
60
+ "packed": nn.Parameter(
61
+ torch.empty(
62
+ out_features * in_features // (128 * 2), 128, dtype=torch.uint8
63
+ ),
64
+ requires_grad=False,
65
+ ),
66
+ "scale": nn.Parameter(
67
+ torch.empty(out_features * in_features // 128, 1),
68
+ requires_grad=False,
69
+ ),
70
+ "zero_point": nn.Parameter(
71
+ torch.empty(out_features * in_features // 128, 1),
72
+ requires_grad=False,
73
+ ),
74
+ }
75
+ )
76
+ self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
77
+ self.unpacked = False
78
+
79
+ def unpack(self):
80
+ if self.unpacked:
81
+ return
82
+
83
+ self.weight = nn.Parameter(
84
+ dequantize_tensor(
85
+ self.weight["packed"],
86
+ self.weight["scale"],
87
+ self.weight["zero_point"],
88
+ (self.out_features, self.in_features),
89
+ torch.bfloat16,
90
+ )
91
+ )
92
+ with torch.device("meta"):
93
+ self.linear = nn.Linear(
94
+ self.in_features, self.out_features, dtype=torch.bfloat16
95
+ )
96
+ self.linear.weight = self.weight
97
+ self.linear.bias = nn.Parameter(
98
+ self.bias.to(torch.bfloat16), requires_grad=False
99
+ )
100
+
101
+ del self.weight, self.bias
102
+ quantize_(self, int4_weight_only(group_size=128))
103
+ self.unpacked = True
104
+ torch.cuda.empty_cache()
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ if not self.unpacked:
108
+ self.unpack()
109
+ return self.linear(x)
110
+
111
+
112
+ @dataclass
113
+ class LayerNormWeights:
114
+ weight: torch.Tensor
115
+ bias: torch.Tensor
116
+
117
+
118
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
119
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
120
+
121
+
122
+ @dataclass
123
+ class MLPWeights:
124
+ fc1: LinearWeights
125
+ fc2: LinearWeights
126
+ act: Literal["gelu_approx"] = "gelu_approx"
127
+
128
+
129
+ def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
130
+ x0 = w.fc1(x)
131
+ if lora is not None:
132
+ x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
133
+ x = x0 + x1
134
+ else:
135
+ x = x0
136
+
137
+ x = gelu_approx(x)
138
+
139
+ x0 = w.fc2(x)
140
+ if lora is not None:
141
+ x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
142
+ x = x0 + x1
143
+ else:
144
+ x = x0
145
+
146
+ return x
147
+
148
+
149
+ @dataclass
150
+ class AttentionWeights:
151
+ qkv: LinearWeights
152
+ proj: LinearWeights
153
+
154
+
155
+ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
156
+ bsz, q_len, d_model = x.shape
157
+ head_dim = d_model // n_heads
158
+
159
+ q, k, v = [
160
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
161
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
162
+ ]
163
+ out = F.scaled_dot_product_attention(q, k, v)
164
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
165
+ out = linear(out, w.proj)
166
+ return out
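A tiny sketch exercising the functional attention helper above with random weights (illustrative only; real weights come from the checkpoint).

```python
import torch
from layers import AttentionWeights, LinearWeights, attn  # assumed import path

d_model, n_heads = 64, 4
w = AttentionWeights(
    qkv=LinearWeights(weight=torch.randn(3 * d_model, d_model), bias=torch.zeros(3 * d_model)),
    proj=LinearWeights(weight=torch.randn(d_model, d_model), bias=torch.zeros(d_model)),
)

x = torch.randn(1, 8, d_model)           # (batch, seq_len, d_model)
out = attn(x, w, n_heads=n_heads)        # multi-head SDPA over the fused qkv projection
print(out.shape)                         # torch.Size([1, 8, 64])
```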
lora.py ADDED
@@ -0,0 +1,82 @@
1
+ import functools
2
+ import os
3
+ import shutil
4
+ import torch
5
+
6
+ from pathlib import Path
7
+ from urllib.request import Request, urlopen
8
+ from typing import Optional
9
+
10
+
11
+ def variant_cache_dir():
12
+ hf_hub_cache = os.environ.get("HF_HUB_CACHE")
13
+ if hf_hub_cache is not None:
14
+ return Path(hf_hub_cache) / "md_variants"
15
+
16
+ hf_home = os.environ.get("HF_HOME")
17
+ if hf_home is not None:
18
+ return Path(hf_home) / "hub" / "md_variants"
19
+
20
+ return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
21
+
22
+
23
+ def cached_variant_path(variant_id: str):
24
+ variant, *rest = variant_id.split("/", 1)
25
+ step = rest[0] if rest else "final"
26
+
27
+ cache_dir = variant_cache_dir() / variant
28
+ os.makedirs(cache_dir, exist_ok=True)
29
+ dest = cache_dir / f"{step}.pt"
30
+ if dest.exists():
31
+ return dest
32
+
33
+ md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
34
+
35
+ headers = {"User-Agent": "moondream-torch"}
36
+ api_key = os.getenv("MOONDREAM_API_KEY")
37
+ if api_key is not None:
38
+ headers["X-Moondream-Auth"] = api_key
39
+
40
+ req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
41
+ with urlopen(req) as r, open(dest, "wb") as f:
42
+ shutil.copyfileobj(r, f)
43
+ return dest
44
+
45
+
46
+ def nest(flat):
47
+ tree = {}
48
+ for k, v in flat.items():
49
+ parts = k.split(".")
50
+ d = tree
51
+ for p in parts[:-1]:
52
+ d = d.setdefault(p, {})
53
+ d[parts[-1]] = v
54
+ return tree
55
+
56
+
57
+ @functools.lru_cache(maxsize=5)
58
+ def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
59
+ if variant_id is None:
60
+ return None
61
+
62
+ state_dict = torch.load(
63
+ cached_variant_path(variant_id), map_location=device, weights_only=True
64
+ )
65
+
66
+ # TODO: Move these into the training code that saves checkpoints...
67
+ rename_rules = [
68
+ ("text_model.transformer.h", "text.blocks"),
69
+ (".mixer", ".attn"),
70
+ (".out_proj", ".proj"),
71
+ (".Wqkv", ".qkv"),
72
+ (".parametrizations.weight.0", ""),
73
+ ]
74
+ new_state_dict = {}
75
+ for key, tensor in state_dict.items():
76
+ new_key = key
77
+ for old, new in rename_rules:
78
+ if old in new_key:
79
+ new_key = new_key.replace(old, new)
80
+ new_state_dict[new_key] = tensor
81
+
82
+ return nest(new_state_dict)
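A small sketch of the `nest` helper above, which turns a flat state dict with dotted keys into the nested tree consumed as a LoRA variant (import path assumed).

```python
from lora import nest  # assumed import path

flat = {
    "text.blocks.0.attn.qkv.A": "tensor_a",
    "text.blocks.0.attn.qkv.B": "tensor_b",
}
print(nest(flat))
# {'text': {'blocks': {'0': {'attn': {'qkv': {'A': 'tensor_a', 'B': 'tensor_b'}}}}}}
```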
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70a7d94c0c8349eb58ed2d9e636ef2d0916960f321ecabeac6354b8ba3d7403f
3
+ size 3854538968
moondream.py ADDED
@@ -0,0 +1,986 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import random
4
+
5
+ from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
6
+ from PIL import Image
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
+
10
+ from .config import MoondreamConfig
11
+ from .image_crops import reconstruct_from_crops
12
+ from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
13
+ from .text import build_text_model, text_encoder, lm_head, text_decoder
14
+ from .region import (
15
+ decode_coordinate,
16
+ encode_coordinate,
17
+ decode_size,
18
+ encode_size,
19
+ encode_spatial_refs,
20
+ SpatialRefs,
21
+ )
22
+ from .layers import QuantizedLinear
23
+ from .lora import variant_state_dict
24
+ from .utils import remove_outlier_points
25
+
26
+ ImageEncodingSettings = TypedDict(
27
+ "ImageEncodingSettings",
28
+ {"variant": str},
29
+ total=False,
30
+ )
31
+
32
+ TextSamplingSettings = TypedDict(
33
+ "TextSamplingSettings",
34
+ {
35
+ "max_tokens": int,
36
+ "temperature": float,
37
+ "top_p": float,
38
+ "variant": str,
39
+ },
40
+ total=False,
41
+ )
42
+
43
+ ObjectSamplingSettings = TypedDict(
44
+ "ObjectSamplingSettings",
45
+ {"max_objects": int, "variant": str},
46
+ total=False,
47
+ )
48
+
49
+
50
+ DEFAULT_MAX_TOKENS = 768
51
+ DEFAULT_TEMPERATURE = 0.5
52
+ DEFAULT_TOP_P = 0.3
53
+ DEFAULT_MAX_OBJECTS = 50
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class EncodedImage:
58
+ pos: int
59
+ caches: List[Tuple[torch.Tensor, torch.Tensor]]
60
+
61
+
62
+ class KVCache(nn.Module):
63
+
64
+ def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
65
+ super().__init__()
66
+ cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
67
+ self.register_buffer(
68
+ "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
69
+ )
70
+ self.register_buffer(
71
+ "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
72
+ )
73
+
74
+ def update(self, pos_ids, k, v):
75
+ kout, vout = self.k_cache, self.v_cache
76
+ kout[:, :, pos_ids, :] = k
77
+ vout[:, :, pos_ids, :] = v
78
+ return kout, vout
79
+
80
+
81
+ class MoondreamModel(nn.Module):
82
+
83
+ def __init__(
84
+ self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
85
+ ):
86
+ super().__init__()
87
+ self.config = config
88
+
89
+ self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
90
+ self.vision = build_vision_model(config.vision, dtype)
91
+ self.text = build_text_model(config.text, dtype)
92
+
93
+ # Region Model
94
+ linear_cls = (
95
+ QuantizedLinear if config.region.group_size is not None else nn.Linear
96
+ )
97
+ self.region = nn.ModuleDict(
98
+ {
99
+ "coord_encoder": linear_cls(
100
+ config.region.coord_feat_dim, config.region.dim, dtype=dtype
101
+ ),
102
+ "coord_decoder": nn.ModuleDict(
103
+ {
104
+ "fc1": linear_cls(
105
+ config.region.dim, config.region.inner_dim, dtype=dtype
106
+ ),
107
+ "fc2": linear_cls(
108
+ config.region.inner_dim,
109
+ config.region.coord_out_dim,
110
+ dtype=dtype,
111
+ ),
112
+ }
113
+ ),
114
+ "size_encoder": linear_cls(
115
+ config.region.size_feat_dim, config.region.dim, dtype=dtype
116
+ ),
117
+ "size_decoder": nn.ModuleDict(
118
+ {
119
+ "fc1": linear_cls(
120
+ config.region.dim, config.region.inner_dim, dtype=dtype
121
+ ),
122
+ "fc2": linear_cls(
123
+ config.region.inner_dim,
124
+ config.region.size_out_dim,
125
+ dtype=dtype,
126
+ ),
127
+ }
128
+ ),
129
+ }
130
+ )
131
+ self.region.coord_features = nn.Parameter(
132
+ torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
133
+ )
134
+ self.region.size_features = nn.Parameter(
135
+ torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
136
+ )
137
+
138
+ attn_mask = torch.tril(
139
+ torch.ones(
140
+ 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
141
+ )
142
+ )
143
+ patch_w = config.vision.crop_size // config.vision.enc_patch_size
144
+ prefix_attn_len = 1 + patch_w**2
145
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
146
+ self.register_buffer("attn_mask", attn_mask, persistent=False)
147
+
148
+ # Initialize KV caches.
149
+ if setup_caches:
150
+ self._setup_caches()
151
+
152
+ def _setup_caches(self):
153
+ c = self.config.text
154
+ for b in self.text.blocks:
155
+ b.kv_cache = KVCache(
156
+ c.n_heads,
157
+ c.n_kv_heads,
158
+ c.max_context,
159
+ c.dim,
160
+ device=self.device,
161
+ dtype=self.vision.pos_emb.dtype,
162
+ )
163
+
164
+ @property
165
+ def device(self):
166
+ return self.vision.pos_emb.device
167
+
168
+ def _vis_enc(self, x: torch.Tensor):
169
+ return vision_encoder(x, self.vision, self.config.vision)
170
+
171
+ def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
172
+ return vision_projection(g, r, self.vision, self.config.vision)
173
+
174
+ def _prefill(
175
+ self,
176
+ x: torch.Tensor,
177
+ attn_mask: torch.Tensor,
178
+ pos_ids: torch.Tensor,
179
+ lora: Optional[torch.Tensor],
180
+ ):
181
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
182
+
183
+ def _decode_one_tok(
184
+ self,
185
+ x: torch.Tensor,
186
+ attn_mask: torch.Tensor,
187
+ pos_ids: torch.Tensor,
188
+ lora: Optional[torch.Tensor],
189
+ ):
190
+ hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
191
+ logits = lm_head(hidden, self.text)
192
+ return logits, hidden
193
+
194
+ def compile(self):
195
+ for module in self.modules():
196
+ if isinstance(module, QuantizedLinear):
197
+ module.unpack()
198
+
199
+ # TODO: vision_projection is not being compiled
200
+ self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
201
+ self._prefill = torch.compile(self._prefill, fullgraph=True)
202
+ self._decode_one_tok = torch.compile(
203
+ self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
204
+ )
205
+
206
+ def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
207
+ all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
208
+
209
+ torch._dynamo.mark_dynamic(all_crops, 0)
210
+
211
+ outputs = self._vis_enc(all_crops)
212
+
213
+ global_features = outputs[0]
214
+ local_features = outputs[1:].view(
215
+ -1,
216
+ self.config.vision.enc_n_layers,
217
+ self.config.vision.enc_n_layers,
218
+ self.config.vision.enc_dim,
219
+ )
220
+
221
+ reconstructed = reconstruct_from_crops(
222
+ local_features,
223
+ tiling,
224
+ patch_size=1,
225
+ overlap_margin=self.config.vision.overlap_margin,
226
+ )
227
+
228
+ return self._vis_proj(global_features, reconstructed)
229
+
230
+ def encode_image(
231
+ self,
232
+ image: Union[Image.Image, EncodedImage],
233
+ settings: Optional[ImageEncodingSettings] = None,
234
+ ) -> EncodedImage:
235
+ if isinstance(image, EncodedImage):
236
+ return image
237
+ elif not isinstance(image, Image.Image):
238
+ raise ValueError("image must be a PIL Image or EncodedImage")
239
+
240
+ lora = (
241
+ variant_state_dict(settings["variant"], device=self.device)
242
+ if settings is not None and settings.get("variant") is not None
243
+ else None
244
+ )
245
+
246
+ # Run through text model in addition to the vision encoder, to minimize
247
+ # re-computation if multiple queries are performed on this image.
248
+ with torch.inference_mode():
249
+ img_emb = self._run_vision_encoder(image)
250
+ bos_emb = text_encoder(
251
+ torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
252
+ self.text,
253
+ )
254
+ inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
255
+ mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
256
+ pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
257
+ self._prefill(inputs_embeds, mask, pos_ids, lora)
258
+
259
+ return EncodedImage(
260
+ pos=inputs_embeds.size(1),
261
+ caches=[
262
+ (
263
+ b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
264
+ b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
265
+ )
266
+ for b in self.text.blocks
267
+ ],
268
+ )
269
+
270
+ def _apply_top_p(self, probs: torch.Tensor, top_p: float):
271
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
272
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
273
+ mask = probs_sum - probs_sort > top_p
274
+ probs_sort[mask] = 0.0
275
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
276
+ next_probs = torch.zeros_like(probs)
277
+ next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
278
+ return next_probs
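+
+ # Worked example (illustrative) of the top-p filtering above: with sorted probs
+ # [0.4, 0.3, 0.2, 0.1] and top_p = 0.5, cumsum - prob = [0.0, 0.4, 0.7, 0.9];
+ # entries above top_p are zeroed, leaving [0.4, 0.3], which renormalize to
+ # [4/7, 3/7] before multinomial sampling.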
279
+
280
+ def _prefill_prompt(
281
+ self,
282
+ prompt_tokens: torch.Tensor,
283
+ pos: int,
284
+ temperature: float,
285
+ top_p: float,
286
+ spatial_refs: Optional[SpatialRefs] = None,
287
+ attn_mask: Optional[torch.Tensor] = None,
288
+ lora: Optional[dict] = None,
289
+ ):
290
+ with torch.inference_mode():
291
+ prompt_emb = text_encoder(prompt_tokens, self.text)
292
+
293
+ if spatial_refs:
294
+ encoded_refs = encode_spatial_refs(spatial_refs, self.region)
295
+ prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
296
+ encoded_refs["coords"]
297
+ )
298
+ if encoded_refs["sizes"] is not None:
299
+ prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
300
+ encoded_refs["sizes"]
301
+ )
302
+
303
+ torch._dynamo.mark_dynamic(prompt_emb, 1)
304
+
305
+ if attn_mask is None:
306
+ attn_mask = self.attn_mask
307
+
308
+ mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
309
+ pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
310
+ hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
311
+ logits_BV = lm_head(hidden_BC, self.text)
312
+
313
+ if temperature == 0:
314
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
315
+ else:
316
+ probs = torch.softmax(logits_BV / temperature, dim=-1)
317
+ probs = self._apply_top_p(probs, top_p)
318
+ next_token = torch.multinomial(probs, num_samples=1)
319
+
320
+ pos = pos + prompt_emb.size(1)
321
+ return logits_BV, hidden_BC, next_token, pos
322
+
323
+ def _generate_reasoning(
324
+ self,
325
+ prompt_tokens,
326
+ pos,
327
+ settings: Optional[TextSamplingSettings] = None,
328
+ spatial_refs: Optional[SpatialRefs] = None,
329
+ attn_mask: Optional[torch.Tensor] = None,
330
+ ) -> Tuple[int, str, List[dict]]:
331
+ max_tokens = (
332
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
333
+ if settings
334
+ else DEFAULT_MAX_TOKENS
335
+ )
336
+ temperature = (
337
+ settings.get("temperature", DEFAULT_TEMPERATURE)
338
+ if settings
339
+ else DEFAULT_TEMPERATURE
340
+ )
341
+ lora = (
342
+ variant_state_dict(settings["variant"], device=self.device)
343
+ if settings is not None and "variant" in settings
344
+ else None
345
+ )
346
+
347
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
348
+ eos_id = self.config.tokenizer.answer_id
349
+
350
+ _, last_hidden_BC, next_token, pos = self._prefill_prompt(
351
+ prompt_tokens,
352
+ pos,
353
+ temperature,
354
+ top_p,
355
+ spatial_refs,
356
+ attn_mask=attn_mask,
357
+ lora=lora,
358
+ )
359
+
360
+ text_token_chunks = [[]]
361
+ grounding_chunks = [[]]
362
+
363
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
364
+ mask[:, :, :pos] = 1
365
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
366
+ generated_tokens = 0
367
+
368
+ while (
369
+ next_token_id := next_token.item()
370
+ ) != eos_id and generated_tokens < max_tokens:
371
+ if (
372
+ next_token_id == self.config.tokenizer.start_ground_points_id
373
+ or next_token_id == self.config.tokenizer.end_ground_id
374
+ ):
375
+ text_token_chunks.append([])
376
+ grounding_chunks.append([])
377
+
378
+ text_token_chunks[-1].append(next_token_id)
379
+
380
+ with torch.inference_mode():
381
+ if next_token_id == self.config.tokenizer.coord_id:
382
+ coord_logits = decode_coordinate(last_hidden_BC, self.region)
383
+ coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
384
+ grounding_chunks[-1].append(coord.item())
385
+
386
+ next_emb = encode_coordinate(
387
+ coord.to(dtype=coord_logits.dtype), self.region
388
+ ).unsqueeze(0)
389
+ else:
390
+ next_emb = text_encoder(next_token, self.text)
391
+
392
+ mask[:, :, pos], pos_ids[0] = 1, pos
393
+
394
+ logits_BV, last_hidden_BC = self._decode_one_tok(
395
+ next_emb, mask, pos_ids, lora
396
+ )
397
+ logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
398
+ logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
399
+
400
+ pos += 1
401
+
402
+ if temperature == 0:
403
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
404
+ else:
405
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
406
+ probs = self._apply_top_p(probs, top_p)
407
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
408
+
409
+ generated_tokens += 1
410
+
411
+ text_chunks = [
412
+ self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
413
+ ]
414
+ text = "".join(text_chunks)
415
+
416
+ start_idx = 0
417
+ grounding = []
418
+ for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
419
+ if len(grounding_chunk) > 1:
420
+ points = []
421
+ for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
422
+ points.append((grounding_chunk[i], grounding_chunk[i + 1]))
423
+ grounding.append(
424
+ {
425
+ "start_idx": start_idx,
426
+ "end_idx": start_idx + len(text_chunk),
427
+ "points": points,
428
+ }
429
+ )
430
+ start_idx += len(text_chunk)
431
+
432
+ return pos, text, grounding
433
+
434
+ def _generate_answer(
435
+ self,
436
+ prompt_tokens: torch.Tensor,
437
+ pos: int,
438
+ settings: Optional[TextSamplingSettings] = None,
439
+ spatial_refs: Optional[SpatialRefs] = None,
440
+ eos_id: Optional[int] = None,
441
+ attn_mask: Optional[torch.Tensor] = None,
442
+ ):
443
+ max_tokens = (
444
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
445
+ if settings
446
+ else DEFAULT_MAX_TOKENS
447
+ )
448
+ temperature = (
449
+ settings.get("temperature", DEFAULT_TEMPERATURE)
450
+ if settings
451
+ else DEFAULT_TEMPERATURE
452
+ )
453
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
454
+ eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
455
+ lora = (
456
+ variant_state_dict(settings["variant"], device=self.device)
457
+ if settings is not None and "variant" in settings
458
+ else None
459
+ )
460
+
461
+ _, _, next_token, pos = self._prefill_prompt(
462
+ prompt_tokens,
463
+ pos,
464
+ temperature,
465
+ top_p,
466
+ spatial_refs,
467
+ attn_mask=attn_mask,
468
+ lora=lora,
469
+ )
470
+
471
+ def generator(next_token, pos):
472
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
473
+ mask[:, :, :pos] = 1
474
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
475
+ generated_tokens = 0
476
+
477
+ # For properly handling token streaming with Unicode
478
+ token_cache = []
479
+ print_len = 0
480
+
481
+ while (
482
+ next_token_id := next_token.item()
483
+ ) != eos_id and generated_tokens < max_tokens:
484
+ # Add token to our cache
485
+ token_cache.append(next_token_id)
486
+
487
+ # Decode all tokens collected so far
488
+ text = self.tokenizer.decode(token_cache)
489
+
490
+ # After a newline, we flush the cache completely
491
+ if text.endswith("\n"):
492
+ printable_text = text[print_len:]
493
+ token_cache = []
494
+ print_len = 0
495
+ if printable_text:
496
+ yield printable_text
497
+ # If the last token is a CJK character, we can safely print it
498
+ elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
499
+ printable_text = text[print_len:]
500
+ print_len += len(printable_text)
501
+ if printable_text:
502
+ yield printable_text
503
+ # Otherwise, only yield up to the last space to avoid cutting words
504
+ else:
505
+ last_space_idx = text.rfind(" ", print_len)
506
+ if last_space_idx >= print_len:
507
+ printable_text = text[print_len : last_space_idx + 1]
508
+ print_len += len(printable_text)
509
+ if printable_text:
510
+ yield printable_text
511
+
512
+ with torch.inference_mode():
513
+ next_emb = text_encoder(next_token, self.text)
514
+ mask[:, :, pos], pos_ids[0] = 1, pos
515
+
516
+ logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
517
+ logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
518
+
519
+ pos += 1
520
+
521
+ if temperature == 0:
522
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
523
+ 1
524
+ ) # (1, 1)
525
+ else:
526
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
527
+ probs = self._apply_top_p(probs, top_p)
528
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
529
+
530
+ generated_tokens += 1
531
+
532
+ # Flush any remaining text in the cache
533
+ if token_cache:
534
+ text = self.tokenizer.decode(token_cache)
535
+ printable_text = text[print_len:]
536
+ if printable_text:
537
+ yield printable_text
538
+
539
+ return generator(next_token, pos)
540
+
541
+ def query(
542
+ self,
543
+ image: Optional[Union[Image.Image, EncodedImage]] = None,
544
+ question: str = None,
545
+ reasoning: bool = False,
546
+ spatial_refs: Optional[SpatialRefs] = None,
547
+ stream: bool = False,
548
+ settings: Optional[TextSamplingSettings] = None,
549
+ ):
550
+ if self.config.tokenizer.templates["query"] is None:
551
+ raise NotImplementedError("Model does not support querying.")
552
+
553
+ if question is None:
554
+ raise ValueError("question must be provided.")
555
+
556
+ if spatial_refs and image is None:
557
+ raise ValueError("spatial_refs can only be used with an image.")
558
+
559
+ attn_mask = self.attn_mask
560
+ if image is not None:
561
+ image = self.encode_image(image, settings)
562
+ self.load_encoded_image(image)
563
+ pos = image.pos
564
+ prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
565
+ else:
566
+ self._setup_caches()
567
+ pos = 0
568
+ prompt_toks = [
569
+ self.config.tokenizer.bos_id
570
+ ] + self.config.tokenizer.templates["query"]["prefix"]
571
+ max_context = self.config.text.max_context
572
+ attn_mask = torch.tril(
573
+ torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
574
+ ).to(self.device)
575
+
576
+ spatial_toks = []
577
+ if spatial_refs:
578
+ for ref in spatial_refs:
579
+ coord_id = self.config.tokenizer.coord_id
580
+ size_id = self.config.tokenizer.size_id
581
+ if len(ref) == 2:
582
+ spatial_toks.extend([coord_id, coord_id])
583
+ else:
584
+ spatial_toks.extend([coord_id, coord_id, size_id])
585
+
586
+ prompt_tokens = [
587
+ prompt_toks
588
+ + spatial_toks
589
+ + self.tokenizer.encode(question).ids
590
+ + self.config.tokenizer.templates["query"]["suffix"]
591
+ ]
592
+
593
+ if reasoning:
594
+ prompt_tokens[0] += [self.config.tokenizer.thinking_id]
595
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
596
+ pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
597
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
598
+ )
599
+ prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
600
+ reasoning_dict = {
601
+ "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
602
+ }
603
+ else:
604
+ prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
605
+ reasoning_dict = {}
606
+
607
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
608
+
609
+ def generator():
610
+ for token in self._generate_answer(
611
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
612
+ ):
613
+ yield token
614
+
615
+ if stream:
616
+ return {**reasoning_dict, "answer": generator()}
617
+ else:
618
+ return {**reasoning_dict, "answer": "".join(list(generator()))}
619
+
620
+ def load_encoded_image(self, encoded_image: EncodedImage):
621
+ for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
622
+ b.kv_cache.k_cache[:, :, : k.size(2), :] = k
623
+ b.kv_cache.v_cache[:, :, : v.size(2), :] = v
624
+
625
+ def caption(
626
+ self,
627
+ image: Union[Image.Image, EncodedImage],
628
+ length: Literal["normal", "short", "long"] = "normal",
629
+ stream: bool = False,
630
+ settings: Optional[TextSamplingSettings] = None,
631
+ ):
632
+ if self.config.tokenizer.templates["caption"] is None:
633
+ raise NotImplementedError("Model does not support captioning.")
634
+ if length not in self.config.tokenizer.templates["caption"]:
635
+ raise ValueError(f"Model does not support caption length '{length}'.")
636
+
637
+ image = self.encode_image(image, settings)
638
+ self.load_encoded_image(image)
639
+
640
+ prompt_tokens = torch.tensor(
641
+ [self.config.tokenizer.templates["caption"][length]], device=self.device
642
+ )
643
+
644
+ def generator():
645
+ for token in self._generate_answer(prompt_tokens, image.pos, settings):
646
+ yield token
647
+
648
+ if stream:
649
+ return {"caption": generator()}
650
+ else:
651
+ return {"caption": "".join(list(generator()))}
652
+
653
+ def _generate_points(
654
+ self,
655
+ hidden: torch.Tensor,
656
+ next_token: torch.Tensor,
657
+ pos: int,
658
+ include_size: bool = True,
659
+ max_objects: int = DEFAULT_MAX_OBJECTS,
660
+ lora: Optional[dict] = None,
661
+ ):
662
+ out = []
663
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
664
+ mask[:, :, :pos] = 1
665
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
666
+
667
+ with torch.inference_mode():
668
+ while (
669
+ next_token.item() != self.config.tokenizer.eos_id
670
+ and len(out) < max_objects
671
+ ):
672
+ x_logits = decode_coordinate(hidden, self.region)
673
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
674
+ next_emb = encode_coordinate(
675
+ x_center.to(dtype=x_logits.dtype), self.region
676
+ ).unsqueeze(0)
677
+
678
+ # Decode y-coordinate
679
+ mask[:, :, pos], pos_ids[0] = 1, pos
680
+ _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
681
+ pos += 1
682
+ y_logits = decode_coordinate(hidden, self.region)
683
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
684
+ next_emb = encode_coordinate(
685
+ y_center.to(dtype=y_logits.dtype), self.region
686
+ ).unsqueeze(0)
687
+
688
+ # Decode size
689
+ if include_size:
690
+ mask[:, :, pos], pos_ids[0] = 1, pos
691
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
692
+ pos += 1
693
+ size_logits = decode_size(hidden, self.region)
694
+
695
+ # Get bin indices from the logits
696
+ w_bin = torch.argmax(size_logits[0], dim=-1)
697
+ h_bin = torch.argmax(size_logits[1], dim=-1)
698
+
699
+ # Convert from bin indices to actual size values using the inverse of the log-scale mapping
700
+ # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
701
+ w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
702
+ h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
703
+
704
+ next_emb = (
705
+ encode_size(
706
+ torch.tensor(
707
+ [w, h], device=self.device, dtype=size_logits.dtype
708
+ ),
709
+ self.region,
710
+ )
711
+ .unsqueeze(0)
712
+ .unsqueeze(0)
713
+ )
714
+
715
+ # Add object
716
+ out.append(
717
+ {
718
+ "x_min": x_center.item() - w.item() / 2,
719
+ "y_min": y_center.item() - h.item() / 2,
720
+ "x_max": x_center.item() + w.item() / 2,
721
+ "y_max": y_center.item() + h.item() / 2,
722
+ }
723
+ )
724
+ else:
725
+ out.append({"x": x_center.item(), "y": y_center.item()})
726
+
727
+ # Decode next token (x-coordinate, or eos)
728
+ mask[:, :, pos], pos_ids[0] = 1, pos
729
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
730
+ pos += 1
731
+ next_token = torch.argmax(logits, dim=-1)
732
+
733
+ return out
734
+
735
+ def detect(
736
+ self,
737
+ image: Union[Image.Image, EncodedImage],
738
+ object: str,
739
+ settings: Optional[ObjectSamplingSettings] = None,
740
+ ):
741
+ if self.config.tokenizer.templates["detect"] is None:
742
+ raise NotImplementedError("Model does not support object detection.")
743
+
744
+ image = self.encode_image(image, settings)
745
+ self.load_encoded_image(image)
746
+
747
+ prompt_tokens = torch.tensor(
748
+ [
749
+ self.config.tokenizer.templates["detect"]["prefix"]
750
+ + self.tokenizer.encode(" " + object).ids
751
+ + self.config.tokenizer.templates["detect"]["suffix"]
752
+ ],
753
+ device=self.device,
754
+ )
755
+
756
+ lora = (
757
+ variant_state_dict(settings["variant"], device=self.device)
758
+ if settings is not None and "variant" in settings
759
+ else None
760
+ )
761
+
762
+ _, hidden, next_token, pos = self._prefill_prompt(
763
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
764
+ )
765
+ hidden = hidden[:, -1:, :]
766
+
767
+ max_objects = (
768
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
769
+ if settings
770
+ else DEFAULT_MAX_OBJECTS
771
+ )
772
+ objects = self._generate_points(
773
+ hidden,
774
+ next_token,
775
+ pos,
776
+ include_size=True,
777
+ max_objects=max_objects,
778
+ lora=lora,
779
+ )
780
+
781
+ return {"objects": objects}
782
+
783
+ def point(
784
+ self,
785
+ image: Union[Image.Image, EncodedImage],
786
+ object: str,
787
+ settings: Optional[ObjectSamplingSettings] = None,
788
+ ):
789
+ if self.config.tokenizer.templates["point"] is None:
790
+ raise NotImplementedError("Model does not support pointing.")
791
+
792
+ image = self.encode_image(image, settings)
793
+ self.load_encoded_image(image)
794
+
795
+ prompt_tokens = torch.tensor(
796
+ [
797
+ self.config.tokenizer.templates["point"]["prefix"]
798
+ + self.tokenizer.encode(" " + object).ids
799
+ + self.config.tokenizer.templates["point"]["suffix"]
800
+ ],
801
+ device=self.device,
802
+ )
803
+
804
+ lora = (
805
+ variant_state_dict(settings["variant"], device=self.device)
806
+ if settings is not None and "variant" in settings
807
+ else None
808
+ )
809
+
810
+ _, hidden, next_token, pos = self._prefill_prompt(
811
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
812
+ )
813
+ hidden = hidden[:, -1:, :]
814
+
815
+ max_objects = (
816
+ settings.get("max_objects", DEFAULT_MAX_OBJECTS)
817
+ if settings
818
+ else DEFAULT_MAX_OBJECTS
819
+ )
820
+ objects = self._generate_points(
821
+ hidden,
822
+ next_token,
823
+ pos,
824
+ include_size=False,
825
+ max_objects=max_objects,
826
+ lora=lora,
827
+ )
828
+
829
+ return {"points": objects}
830
+
831
+ def _detect_gaze(
832
+ self,
833
+ image: EncodedImage,
834
+ source: Tuple[float, float],
835
+ force_detect: bool = False,
836
+ ):
837
+ with torch.inference_mode():
838
+ before_emb = text_encoder(
839
+ torch.tensor(
840
+ [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
841
+ ),
842
+ self.text,
843
+ )
844
+ after_emb = text_encoder(
845
+ torch.tensor(
846
+ [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
847
+ ),
848
+ self.text,
849
+ )
850
+ x_emb = encode_coordinate(
851
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
852
+ self.region,
853
+ )
854
+ y_emb = encode_coordinate(
855
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
856
+ self.region,
857
+ )
858
+
859
+ prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
860
+
861
+ self.load_encoded_image(image)
862
+
863
+ mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
864
+ pos_ids = torch.arange(
865
+ image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
866
+ )
867
+ hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
868
+ logits = lm_head(hidden, self.text)
869
+ next_token = torch.argmax(logits, dim=-1)
870
+ pos = image.pos + prompt_emb.size(1)
871
+ hidden = hidden[:, -1:, :]
872
+
873
+ if force_detect:
874
+ next_token = torch.tensor([[0]], device=self.device)
875
+
876
+ if next_token.item() == self.config.tokenizer.eos_id:
877
+ return None
878
+
879
+ gaze = self._generate_points(
880
+ hidden, next_token, pos, include_size=False, max_objects=1
881
+ )
882
+ return gaze[0]
883
+
884
+ def detect_gaze(
885
+ self,
886
+ image: Union[Image.Image, EncodedImage],
887
+ eye: Optional[Tuple[float, float]] = None,
888
+ face: Optional[Dict[str, float]] = None,
889
+ unstable_settings: Dict[str, Any] = {},
890
+ ):
891
+ if "force_detect" in unstable_settings:
892
+ force_detect = unstable_settings["force_detect"]
893
+ else:
894
+ force_detect = False
895
+
896
+ if "prioritize_accuracy" in unstable_settings:
897
+ prioritize_accuracy = unstable_settings["prioritize_accuracy"]
898
+ else:
899
+ prioritize_accuracy = False
900
+
901
+ if not prioritize_accuracy:
902
+ if eye is None:
903
+ raise ValueError("eye must be provided when prioritize_accuracy=False")
904
+ image = self.encode_image(image)
905
+ return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
906
+ else:
907
+ if (
908
+ not isinstance(image, Image.Image)
909
+ and "flip_enc_img" not in unstable_settings
910
+ ):
911
+ raise ValueError(
912
+ "image must be a PIL Image when prioritize_accuracy=True, "
913
+ "or flip_enc_img must be provided"
914
+ )
915
+ if face is None:
916
+ raise ValueError("face must be provided when prioritize_accuracy=True")
917
+
918
+ encoded_image = self.encode_image(image)
919
+ if (
920
+ isinstance(image, Image.Image)
921
+ and "flip_enc_img" not in unstable_settings
922
+ ):
923
+ flipped_pil = image.copy()
924
+ flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
925
+ encoded_flipped_image = self.encode_image(flipped_pil)
926
+ else:
927
+ encoded_flipped_image = unstable_settings["flip_enc_img"]
928
+
929
+ N = 10
930
+
931
+ detections = [
932
+ self._detect_gaze(
933
+ encoded_image,
934
+ (
935
+ random.uniform(face["x_min"], face["x_max"]),
936
+ random.uniform(face["y_min"], face["y_max"]),
937
+ ),
938
+ force_detect=force_detect,
939
+ )
940
+ for _ in range(N)
941
+ ]
942
+ detections = [
943
+ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
944
+ ]
945
+ flipped_detections = [
946
+ self._detect_gaze(
947
+ encoded_flipped_image,
948
+ (
949
+ 1 - random.uniform(face["x_min"], face["x_max"]),
950
+ random.uniform(face["y_min"], face["y_max"]),
951
+ ),
952
+ force_detect=force_detect,
953
+ )
954
+ for _ in range(N)
955
+ ]
956
+ detections.extend(
957
+ [
958
+ (1 - gaze["x"], gaze["y"])
959
+ for gaze in flipped_detections
960
+ if gaze is not None
961
+ ]
962
+ )
963
+
964
+ if len(detections) < N:
965
+ return {"gaze": None}
966
+
967
+ detections = remove_outlier_points(detections)
968
+ mean_gaze = (
969
+ sum(gaze[0] for gaze in detections) / len(detections),
970
+ sum(gaze[1] for gaze in detections) / len(detections),
971
+ )
972
+
973
+ return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
974
+
975
+
976
+ def _is_cjk_char(cp):
977
+ """Checks whether CP is the codepoint of a CJK character."""
978
+ # This defines a "chinese character" as anything in the CJK Unicode block:
979
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
980
+ if (
981
+ (cp >= 0x4E00 and cp <= 0x9FFF)
982
+ or (cp >= 0x3400 and cp <= 0x4DBF)
983
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
984
+ ):
985
+ return True
986
+ return False
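
For reference, a minimal self-contained sketch of the streaming flush policy used by the generator in moondream.py above (flush everything on a newline, emit a trailing CJK character immediately, otherwise emit only up to the last space). `decode` and `stream_pieces` are stand-in names, not repository code:

```python
# Sketch of the streaming flush policy from the generator above.
# `decode` is a stand-in for self.tokenizer.decode (assumption).

def _is_cjk_char(cp: int) -> bool:
    return (
        (0x4E00 <= cp <= 0x9FFF)
        or (0x3400 <= cp <= 0x4DBF)
        or (0x2F800 <= cp <= 0x2FA1F)
    )


def stream_pieces(token_ids, decode):
    token_cache, print_len = [], 0
    for token_id in token_ids:
        token_cache.append(token_id)
        text = decode(token_cache)
        if text.endswith("\n"):
            # Newline: flush everything and reset the cache.
            piece, token_cache, print_len = text[print_len:], [], 0
        elif text and _is_cjk_char(ord(text[-1])):
            # A trailing CJK character is safe to emit immediately.
            piece = text[print_len:]
            print_len += len(piece)
        else:
            # Otherwise emit only up to the last space so words are not split.
            idx = text.rfind(" ", print_len)
            piece = text[print_len : idx + 1] if idx >= print_len else ""
            print_len += len(piece)
        if piece:
            yield piece
    # Flush whatever remains once generation stops.
    tail = decode(token_cache)[print_len:]
    if tail:
        yield tail


print("".join(stream_pieces([1, 2, 3], decode=lambda ids: " ".join(map(str, ids)))))  # "1 2 3"
```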
moondream2-mmproj-f16.gguf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cc1cb3660d87ff56432ebeb7884ad35d67c48c7b9f6b2856f305e39c38eed8f
3
+ size 909777984
moondream2-text-model-f16.gguf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e17e9107fb8781629b3c8ce177de57ffeae90fe14adcf7b99f0eef025889696
3
+ size 2839534976
region.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from typing import List, Tuple, Union
6
+
7
+ from .layers import mlp
8
+
9
+ SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
10
+
11
+
12
+ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Applies Fourier feature mapping to input tensor x using frequency matrix w. This
15
+ projects inputs through sinusoidal functions to create higher dimensional features
16
+ that help mitigate spectral bias - the tendency of neural networks to learn
17
+ low-frequency functions more easily than high-frequency ones. By explicitly
18
+ mapping inputs to higher frequencies through sin/cos transformations, we enable
19
+ better learning of fine details and higher frequency patterns.
20
+
21
+ Args:
22
+ x: Input tensor to transform
23
+ w: Matrix of frequencies for the Fourier features transformation
24
+
25
+ Returns:
26
+ Concatenated cosine and sine transformed features as a tensor
27
+ """
28
+ f = 2 * math.pi * x @ w
29
+ return torch.cat([f.cos(), f.sin()], dim=-1)
30
+
31
+
32
+ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
33
+ """
34
+ Takes as input a tensor containing a single float coordinate value (x or y)
35
+ and encodes it into hidden states for input to the text model.
36
+
37
+ Args:
38
+ coord: Tensor with single float coordinate value
39
+
40
+ Returns:
41
+ Encoded hidden states tensor for input to text model
42
+ """
43
+ return w.coord_encoder(fourier_features(coord, w.coord_features))
44
+
45
+
46
+ def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
47
+ """
48
+ Takes as input the last hidden state from the text model and outputs logits
49
+ over coordinate bins for either an x or y coordinate prediction.
50
+
51
+ Args:
52
+ hidden_state: The final hidden state tensor from the text model.
53
+
54
+ Returns:
55
+ Logits over coordinate bins; the argmax bin divided by the number of bins gives the predicted normalized coordinate (x or y)
56
+ """
57
+ return mlp(hidden_state, w.coord_decoder)
58
+
59
+
60
+ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
61
+ """
62
+ Takes a tensor containing width and height values and encodes them into
63
+ hidden states for input to the text model.
64
+
65
+ Args:
66
+ size: Tensor with two floats for width and height
67
+
68
+ Returns:
69
+ Encoded hidden states tensor for input to text model
70
+ """
71
+ return w.size_encoder(fourier_features(size, w.size_features))
72
+
73
+
74
+ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
75
+ """
76
+ Takes as input the last hidden state from the text model and outputs logits
77
+ for 1024 bins representing width and height in log-scale.
78
+
79
+ The bins are distributed according to the formula:
80
+ bin = (log2(size) + 10.0) / 10.0 * 1023.0
81
+ where size values are clamped to be at least 1/1024.
82
+
83
+ To convert from bin back to size:
84
+ size = 2^((bin / 1023.0) * 10.0 - 10.0)
85
+
86
+ Args:
87
+ hidden_state: The final hidden state tensor from the text model.
88
+
89
+ Returns:
90
+ A tensor containing logits for 1024 bins for width and height.
91
+ Shape is (2, 1024) where the first dimension corresponds to width and height.
92
+ """
93
+ return mlp(hidden_state, w.size_decoder).view(2, -1)
94
+
95
+
96
+ def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
97
+ """
98
+ Takes a list of spatial references (points or regions) and encodes them into
99
+ hidden states for input to the text model.
100
+
101
+ Args:
102
+ spatial_refs: List of spatial references (points or boxes)
103
+ - Points are represented as normalized (x, y) tuples
104
+ - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
105
+
106
+ Returns:
107
+ {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
108
+ """
109
+ coords, sizes = [], []
110
+ for ref in spatial_refs:
111
+ if len(ref) == 2:
112
+ coords.append(ref[0])
113
+ coords.append(ref[1])
114
+ else:
115
+ x_c = (ref[0] + ref[2]) / 2
116
+ y_c = (ref[1] + ref[3]) / 2
117
+ width = ref[2] - ref[0]
118
+ height = ref[3] - ref[1]
119
+ coords.append(x_c)
120
+ coords.append(y_c)
121
+ sizes.append([width, height])
122
+
123
+ coords = torch.tensor(
124
+ coords, device=w.coord_features.device, dtype=w.coord_features.dtype
125
+ ).view(-1, 1)
126
+ coords = encode_coordinate(coords, w)
127
+
128
+ if sizes:
129
+ sizes = torch.tensor(
130
+ sizes, device=w.size_features.device, dtype=w.size_features.dtype
131
+ )
132
+ sizes = encode_size(sizes, w)
133
+ else:
134
+ sizes = None
135
+
136
+ return {"coords": coords, "sizes": sizes}
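
A small round-trip of the log-scale size binning documented in `decode_size` above (bin = (log2(size) + 10) / 10 * 1023 and size = 2^((bin / 1023) * 10 - 10)); the helper names and sample values are illustrative assumptions, not part of the upload:

```python
import math

# Round-trip of the size binning from decode_size's docstring (illustrative).
def size_to_bin(size: float) -> int:
    size = max(size, 1.0 / 1024.0)            # clamp, as the docstring describes
    return round((math.log2(size) + 10.0) / 10.0 * 1023.0)


def bin_to_size(b: int) -> float:
    return 2.0 ** ((b / 1023.0) * 10.0 - 10.0)


for s in (0.05, 0.25, 1.0):
    b = size_to_bin(s)
    print(f"size={s:.3f} -> bin={b} -> size~{bin_to_size(b):.3f}")
```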
region_model.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .fourier_features import FourierFeatures
4
+
5
+ class RegionModel(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ self.position_features = FourierFeatures(2, 256)
10
+ self.position_encoder = nn.Linear(256, 2048)
11
+ self.size_features = FourierFeatures(2, 256)
12
+ self.size_encoder = nn.Linear(256, 2048)
13
+
14
+ self.position_decoder = nn.Linear(2048, 2)
15
+ self.size_decoder = nn.Linear(2048, 2)
16
+ self.confidence_decoder = nn.Linear(2048, 1)
17
+
18
+ def encode_position(self, position):
19
+ return self.position_encoder(self.position_features(position))
20
+
21
+ def encode_size(self, size):
22
+ return self.size_encoder(self.size_features(size))
23
+
24
+ def decode_position(self, x):
25
+ return self.position_decoder(x)
26
+
27
+ def decode_size(self, x):
28
+ return self.size_decoder(x)
29
+
30
+ def decode_confidence(self, x):
31
+ return self.confidence_decoder(x)
32
+
33
+ def encode(self, position, size):
34
+ return torch.stack(
35
+ [self.encode_position(position), self.encode_size(size)], dim=0
36
+ )
37
+
38
+ def decode(self, position_logits, size_logits):
39
+ return (
40
+ self.decode_position(position_logits),
41
+ self.decode_size(size_logits),
42
+ self.decode_confidence(size_logits),
43
+ )
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ einops
2
+ pyvips-binary==8.16.0
3
+ pyvips==2.2.3
rope.py ADDED
@@ -0,0 +1,48 @@
1
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
2
+
3
+ import torch
4
+
5
+
6
+ def precompute_freqs_cis(
7
+ dim: int,
8
+ end: int,
9
+ theta: float = 10000.0,
10
+ use_scaled: bool = False,
11
+ dtype: torch.dtype = torch.float32,
12
+ ) -> torch.Tensor:
13
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
14
+ t = torch.arange(end, dtype=dtype).unsqueeze(1)
15
+ freqs = t * freqs.unsqueeze(0)
16
+ freqs = torch.exp(1j * freqs)
17
+ return torch.stack([freqs.real, freqs.imag], dim=-1)
18
+
19
+
20
+ def apply_rotary_emb(
21
+ x: torch.Tensor,
22
+ freqs_cis: torch.Tensor,
23
+ position_ids: torch.Tensor,
24
+ num_heads: int,
25
+ rot_dim: int = 32,
26
+ interleave: bool = False,
27
+ ) -> torch.Tensor:
28
+ assert rot_dim == freqs_cis.shape[-2] * 2
29
+ assert num_heads == x.shape[1]
30
+
31
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
32
+
33
+ if interleave:
34
+ xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
35
+ xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
36
+ else:
37
+ d_q = x_rot.shape[-1] // 2
38
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
39
+
40
+ freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
41
+ freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
42
+
43
+ # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
44
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
45
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
46
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
47
+
48
+ return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
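
A self-contained sketch of what `apply_rotary_emb` does to the first `rot_dim` channels: each (real, imaginary) pair is multiplied by e^{i·pos·freq}, which preserves the norm of the rotated slice. The dimensions and position below are assumptions chosen only for illustration:

```python
import torch

# The first rot_dim channels are rotated as complex pairs by e^{i*pos*freq};
# rotation preserves the norm of the rotated slice. Sizes here are assumptions.
dim, n_pos = 32, 16
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))   # (dim/2,)
angles = torch.arange(n_pos).float().unsqueeze(1) * freqs            # (n_pos, dim/2)
cos, sin = angles.cos(), angles.sin()

x = torch.randn(dim)                       # one head's rotary slice at position p
p = 5
xr, xi = x[: dim // 2], x[dim // 2 :]      # non-interleaved split, as in apply_rotary_emb
out_r = xr * cos[p] - xi * sin[p]          # (a + bi)(c + di): real part
out_i = xr * sin[p] + xi * cos[p]          # imaginary part

print(torch.allclose(x.norm(), torch.cat([out_r, out_i]).norm(), atol=1e-5))  # True
```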
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
text.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.nn import functional as F
5
+ from typing import Optional
6
+
7
+ from .layers import layer_norm, mlp, QuantizedLinear
8
+ from .rope import apply_rotary_emb, precompute_freqs_cis
9
+ from .config import TextConfig
10
+
11
+
12
+ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
13
+ return F.embedding(input_ids, w.wte)
14
+
15
+
16
+ def attn(
17
+ x: torch.Tensor,
18
+ w: nn.Module,
19
+ freqs_cis: torch.Tensor,
20
+ kv_cache: nn.Module,
21
+ attn_mask: torch.Tensor,
22
+ n_heads: int,
23
+ n_kv_heads: int,
24
+ position_ids: torch.Tensor,
25
+ lora: Optional[dict],
26
+ ):
27
+ bsz, q_len, d_model = x.shape
28
+ head_dim = d_model // n_heads
29
+
30
+ qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
31
+ if lora is not None:
32
+ qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
33
+ q_dim = n_heads * head_dim
34
+ kv_dim = n_kv_heads * head_dim
35
+ q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
36
+ del qkv_out
37
+
38
+ q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
39
+ k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
40
+ v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
41
+
42
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
43
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
44
+
45
+ if kv_cache is not None:
46
+ k, v = kv_cache.update(position_ids, k, v)
47
+
48
+ out = F.scaled_dot_product_attention(
49
+ q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
50
+ )
51
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
52
+
53
+ out0 = w.proj(out)
54
+ if lora is not None:
55
+ out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
56
+ out = out0 + out1
57
+ else:
58
+ out = out0
59
+
60
+ return out
61
+
62
+
63
+ def _attn(
64
+ x: torch.Tensor,
65
+ w: torch.Tensor,
66
+ freqs_cis: torch.Tensor,
67
+ attn_mask: torch.Tensor,
68
+ n_heads: int,
69
+ n_kv_heads: int,
70
+ ):
71
+ bsz, q_len, d_model = x.shape
72
+ head_dim = d_model // n_heads
73
+ pos = 0
74
+
75
+ qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
76
+ q_dim = n_heads * head_dim
77
+ kv_dim = n_kv_heads * head_dim
78
+
79
+ q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
80
+ k = (
81
+ qkv_out[..., q_dim : q_dim + kv_dim]
82
+ .view(bsz, q_len, n_kv_heads, head_dim)
83
+ .transpose(1, 2)
84
+ )
85
+ v = (
86
+ qkv_out[..., q_dim + kv_dim :]
87
+ .view(bsz, q_len, n_kv_heads, head_dim)
88
+ .transpose(1, 2)
89
+ )
90
+
91
+ position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
92
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
93
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
94
+ out = F.scaled_dot_product_attention(
95
+ q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
96
+ )
97
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
98
+ out = w.proj(out)
99
+ return out
100
+
101
+
102
+ def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
103
+ hidden_BTC = inputs_embeds
104
+
105
+ bsz, q_len, d_model = inputs_embeds.shape
106
+ attn_mask = torch.zeros(q_len, q_len)
107
+ attn_mask[:730, :730] = 1
108
+ for i in range(730, q_len):
109
+ attn_mask[i, : i + 1] = 1
110
+ attn_mask = attn_mask.to(dtype=torch.bool)
111
+
112
+ for i, block in enumerate(w.blocks):
113
+ l_in = layer_norm(hidden_BTC, block.ln)
114
+ l_attn = _attn(
115
+ x=l_in,
116
+ w=block.attn,
117
+ freqs_cis=w.freqs_cis,
118
+ attn_mask=attn_mask,
119
+ n_heads=config.n_heads,
120
+ n_kv_heads=config.n_kv_heads,
121
+ )
122
+ l_mlp = mlp(l_in, block.mlp)
123
+ hidden_BTC = hidden_BTC + l_attn + l_mlp
124
+
125
+ return hidden_BTC
126
+
127
+
128
+ def text_decoder(
129
+ x: torch.Tensor,
130
+ w: nn.Module,
131
+ attn_mask: torch.Tensor,
132
+ position_ids: torch.Tensor,
133
+ config: TextConfig,
134
+ lora: Optional[dict],
135
+ ):
136
+ for i, block in enumerate(w.blocks):
137
+ if lora is not None:
138
+ layer_lora = lora["text"]["blocks"][str(i)]
139
+ mlp_lora = layer_lora["mlp"]
140
+ attn_lora = layer_lora["attn"]
141
+ else:
142
+ mlp_lora = None
143
+ attn_lora = None
144
+
145
+ l_in = layer_norm(x, block.ln)
146
+ l_attn = attn(
147
+ l_in,
148
+ block.attn,
149
+ freqs_cis=w.freqs_cis,
150
+ kv_cache=block.kv_cache,
151
+ attn_mask=attn_mask,
152
+ n_heads=config.n_heads,
153
+ n_kv_heads=config.n_kv_heads,
154
+ position_ids=position_ids,
155
+ lora=attn_lora,
156
+ )
157
+ l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
158
+ x = x + l_attn + l_mlp
159
+
160
+ return x
161
+
162
+
163
+ def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
164
+ hidden_BC = hidden_BTC[:, -1, :]
165
+ hidden_BC = layer_norm(hidden_BC, w.post_ln)
166
+ logits = w.lm_head(hidden_BC)
167
+ return logits
168
+
169
+
170
+ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
171
+ hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
172
+ logits = w.lm_head(hidden_BTC)
173
+ return logits
174
+
175
+
176
+ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
177
+ qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
178
+ linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
179
+
180
+ text = nn.ModuleDict(
181
+ {
182
+ "blocks": nn.ModuleList(
183
+ [
184
+ nn.ModuleDict(
185
+ {
186
+ "ln": nn.LayerNorm(config.dim, dtype=dtype),
187
+ "attn": nn.ModuleDict(
188
+ {
189
+ "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
190
+ "proj": linear_cls(
191
+ config.dim, config.dim, dtype=dtype
192
+ ),
193
+ }
194
+ ),
195
+ "mlp": nn.ModuleDict(
196
+ {
197
+ "fc1": linear_cls(
198
+ config.dim, config.ff_dim, dtype=dtype
199
+ ),
200
+ "fc2": linear_cls(
201
+ config.ff_dim, config.dim, dtype=dtype
202
+ ),
203
+ }
204
+ ),
205
+ }
206
+ )
207
+ for _ in range(config.n_layers)
208
+ ]
209
+ ),
210
+ "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
211
+ "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
212
+ }
213
+ )
214
+ text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
215
+ text.register_buffer(
216
+ "freqs_cis",
217
+ precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
218
+ persistent=False,
219
+ )
220
+
221
+ return text
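
The mask built with a Python loop in `_produce_hidden` above (the first 730 positions attend to each other bidirectionally, every later position attends causally) can also be written with tensor ops; a sketch, assuming the same 730-token prefix:

```python
import torch

def prefix_causal_mask(q_len: int, prefix: int = 730) -> torch.Tensor:
    # Causal everywhere, then make the prefix block fully bidirectional.
    mask = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))
    mask[:prefix, :prefix] = True
    return mask

print(prefix_causal_mask(8, prefix=4).int())
```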
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,323 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "50257": {
13
+ "content": " ",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": false
19
+ },
20
+ "50258": {
21
+ "content": " ",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": false
27
+ },
28
+ "50259": {
29
+ "content": " ",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": false
35
+ },
36
+ "50260": {
37
+ "content": " ",
38
+ "lstrip": false,
39
+ "normalized": true,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": false
43
+ },
44
+ "50261": {
45
+ "content": " ",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": false
51
+ },
52
+ "50262": {
53
+ "content": " ",
54
+ "lstrip": false,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": false
59
+ },
60
+ "50263": {
61
+ "content": " ",
62
+ "lstrip": false,
63
+ "normalized": true,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": false
67
+ },
68
+ "50264": {
69
+ "content": " ",
70
+ "lstrip": false,
71
+ "normalized": true,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": false
75
+ },
76
+ "50265": {
77
+ "content": " ",
78
+ "lstrip": false,
79
+ "normalized": true,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": false
83
+ },
84
+ "50266": {
85
+ "content": " ",
86
+ "lstrip": false,
87
+ "normalized": true,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": false
91
+ },
92
+ "50267": {
93
+ "content": " ",
94
+ "lstrip": false,
95
+ "normalized": true,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": false
99
+ },
100
+ "50268": {
101
+ "content": " ",
102
+ "lstrip": false,
103
+ "normalized": true,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": false
107
+ },
108
+ "50269": {
109
+ "content": " ",
110
+ "lstrip": false,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": false
115
+ },
116
+ "50270": {
117
+ "content": " ",
118
+ "lstrip": false,
119
+ "normalized": true,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": false
123
+ },
124
+ "50271": {
125
+ "content": " ",
126
+ "lstrip": false,
127
+ "normalized": true,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": false
131
+ },
132
+ "50272": {
133
+ "content": " ",
134
+ "lstrip": false,
135
+ "normalized": true,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": false
139
+ },
140
+ "50273": {
141
+ "content": " ",
142
+ "lstrip": false,
143
+ "normalized": true,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": false
147
+ },
148
+ "50274": {
149
+ "content": " ",
150
+ "lstrip": false,
151
+ "normalized": true,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": false
155
+ },
156
+ "50275": {
157
+ "content": " ",
158
+ "lstrip": false,
159
+ "normalized": true,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": false
163
+ },
164
+ "50276": {
165
+ "content": " ",
166
+ "lstrip": false,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": false
171
+ },
172
+ "50277": {
173
+ "content": " ",
174
+ "lstrip": false,
175
+ "normalized": true,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": false
179
+ },
180
+ "50278": {
181
+ "content": " ",
182
+ "lstrip": false,
183
+ "normalized": true,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": false
187
+ },
188
+ "50279": {
189
+ "content": " ",
190
+ "lstrip": false,
191
+ "normalized": true,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": false
195
+ },
196
+ "50280": {
197
+ "content": " ",
198
+ "lstrip": false,
199
+ "normalized": true,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": false
203
+ },
204
+ "50281": {
205
+ "content": " ",
206
+ "lstrip": false,
207
+ "normalized": true,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": false
211
+ },
212
+ "50282": {
213
+ "content": " ",
214
+ "lstrip": false,
215
+ "normalized": true,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": false
219
+ },
220
+ "50283": {
221
+ "content": " ",
222
+ "lstrip": false,
223
+ "normalized": true,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": false
227
+ },
228
+ "50284": {
229
+ "content": " ",
230
+ "lstrip": false,
231
+ "normalized": true,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": false
235
+ },
236
+ "50285": {
237
+ "content": " ",
238
+ "lstrip": false,
239
+ "normalized": true,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": false
243
+ },
244
+ "50286": {
245
+ "content": " ",
246
+ "lstrip": false,
247
+ "normalized": true,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": false
251
+ },
252
+ "50287": {
253
+ "content": "\t\t\t\t\t\t\t\t\t",
254
+ "lstrip": false,
255
+ "normalized": true,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": false
259
+ },
260
+ "50288": {
261
+ "content": "\t\t\t\t\t\t\t\t",
262
+ "lstrip": false,
263
+ "normalized": true,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": false
267
+ },
268
+ "50289": {
269
+ "content": "\t\t\t\t\t\t\t",
270
+ "lstrip": false,
271
+ "normalized": true,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": false
275
+ },
276
+ "50290": {
277
+ "content": "\t\t\t\t\t\t",
278
+ "lstrip": false,
279
+ "normalized": true,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": false
283
+ },
284
+ "50291": {
285
+ "content": "\t\t\t\t\t",
286
+ "lstrip": false,
287
+ "normalized": true,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": false
291
+ },
292
+ "50292": {
293
+ "content": "\t\t\t\t",
294
+ "lstrip": false,
295
+ "normalized": true,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": false
299
+ },
300
+ "50293": {
301
+ "content": "\t\t\t",
302
+ "lstrip": false,
303
+ "normalized": true,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": false
307
+ },
308
+ "50294": {
309
+ "content": "\t\t",
310
+ "lstrip": false,
311
+ "normalized": true,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": false
315
+ }
316
+ },
317
+ "bos_token": "<|endoftext|>",
318
+ "clean_up_tokenization_spaces": true,
319
+ "eos_token": "<|endoftext|>",
320
+ "model_max_length": 2048,
321
+ "tokenizer_class": "CodeGenTokenizer",
322
+ "unk_token": "<|endoftext|>"
323
+ }
utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+
4
+ def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
5
+ """
6
+ Robust outlier detection for list of (x,y) tuples.
7
+ Only requires numpy.
8
+
9
+ Args:
10
+ points_tuples: list of (x,y) tuples
11
+ k_nearest: number of neighbors to consider
12
+ threshold: multiplier for median distance
13
+
14
+ Returns:
15
+ list: filtered list of (x,y) tuples with outliers removed
16
+ (points whose average neighbor distance exceeds threshold * the median distance are dropped)
17
+ """
18
+ points = np.array(points_tuples)
19
+ n_points = len(points)
20
+
21
+ # Calculate pairwise distances manually
22
+ dist_matrix = np.zeros((n_points, n_points))
23
+ for i in range(n_points):
24
+ for j in range(i + 1, n_points):
25
+ # Euclidean distance between points i and j
26
+ dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
27
+ dist_matrix[i, j] = dist
28
+ dist_matrix[j, i] = dist
29
+
30
+ # Get k nearest neighbors' distances
31
+ k = min(k_nearest, n_points - 1)
32
+ neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
33
+ avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
34
+
35
+ # Calculate mask using median distance
36
+ median_dist = np.median(avg_neighbor_dist)
37
+ mask = avg_neighbor_dist <= threshold * median_dist
38
+
39
+ # Return only the tuples that passed the outlier mask
40
+ filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
41
+ return filtered_tuples
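
Illustrative use of `remove_outlier_points` with made-up gaze estimates (the values are assumptions), mirroring how `detect_gaze` averages the surviving points:

```python
from utils import remove_outlier_points  # utils.py above; adjust the import to your layout

points = [(0.50 + 0.01 * i, 0.40) for i in range(9)] + [(0.95, 0.95)]  # one far-off point
kept = remove_outlier_points(points)

mean = (
    sum(x for x, _ in kept) / len(kept),
    sum(y for _, y in kept) / len(kept),
)
print(len(kept), mean)   # the far-off point is expected to be dropped
```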
versions.txt ADDED
@@ -0,0 +1,12 @@
1
+ 2024-03-04
2
+ 2024-03-06
3
+ 2024-03-13
4
+ 2024-04-02
5
+ 2024-05-08
6
+ 2024-05-20
7
+ 2024-07-23
8
+ 2024-08-26
9
+ 2025-01-09
10
+ 2025-03-27
11
+ 2025-04-14
12
+ 2025-06-21
vision.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ from typing import Union, Tuple
7
+ from PIL import Image
8
+
9
+ from .layers import attn, layer_norm, mlp
10
+ from .image_crops import overlap_crop_image
11
+ from .config import VisionConfig
12
+
13
+ if torch.backends.mps.is_available():
14
+ # Non-divisible input sizes are not implemented on MPS device yet.
15
+ # https://github.com/pytorch/pytorch/issues/96056
16
+ def adaptive_avg_pool2d(input, output_size):
17
+ return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
18
+
19
+ else:
20
+ adaptive_avg_pool2d = F.adaptive_avg_pool2d
21
+
22
+ DeviceLike = Union[str, torch.device, int]
23
+
24
+
25
+ def prepare_crops(
26
+ image: Image.Image, config: VisionConfig, device: DeviceLike
27
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
28
+ np_image = np.array(image.convert("RGB"))
29
+ overlap_crops = overlap_crop_image(
30
+ np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
31
+ )
32
+ all_crops = overlap_crops["crops"]
33
+ all_crops = np.transpose(all_crops, (0, 3, 1, 2))
34
+ all_crops = (
35
+ torch.from_numpy(all_crops)
36
+ .to(device=device, dtype=torch.bfloat16)
37
+ .div_(255.0)
38
+ .sub_(0.5)
39
+ .div_(0.5)
40
+ )
41
+ return all_crops, overlap_crops["tiling"]
42
+
43
+
44
+ def create_patches(x, patch_size):
45
+ # Original shape: [B, C, H, W]
46
+ B, C, H, W = x.shape
47
+ P1 = P2 = patch_size
48
+
49
+ # Step 1: Split H and W dimensions into patches
50
+ # [B, C, H/P1, P1, W/P2, P2]
51
+ x = x.reshape(B, C, H // P1, P1, W // P2, P2)
52
+
53
+ # Step 2: Rearrange dimensions to match target shape
54
+ # [B, H/P1, W/P2, C, P1, P2]
55
+ x = x.permute(0, 2, 4, 1, 3, 5)
56
+
57
+ # Step 3: Combine dimensions to get final shape
58
+ # [B, (H/P1)*(W/P2), C*P1*P2]
59
+ x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
60
+
61
+ return x
62
+
63
+
64
+ def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
65
+ x = create_patches(input_BCHW, config.enc_patch_size)
66
+
67
+ x = w.patch_emb(x)
68
+ x = x + w.pos_emb
69
+ for block in w.blocks:
70
+ x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
71
+ x = x + mlp(layer_norm(x, block.ln2), block.mlp)
72
+ x = layer_norm(x, w.post_ln)
73
+
74
+ return x
75
+
76
+
77
+ def vision_projection(
78
+ global_features: torch.Tensor,
79
+ reconstructed: torch.Tensor,
80
+ w: nn.Module,
81
+ config: VisionConfig,
82
+ ):
83
+ reconstructed = reconstructed.permute(2, 0, 1)
84
+ reconstructed = adaptive_avg_pool2d(
85
+ reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
86
+ )
87
+ reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
88
+ final_features = torch.cat([global_features, reconstructed], dim=-1)
89
+ return mlp(final_features, w.proj_mlp)
90
+
91
+
92
+ def build_vision_model(config: VisionConfig, dtype: torch.dtype):
93
+ patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
94
+ grid_size = config.crop_size // config.enc_patch_size
95
+ num_patches = grid_size * grid_size
96
+
97
+ vision = nn.ModuleDict(
98
+ {
99
+ "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
100
+ "blocks": nn.ModuleList(
101
+ [
102
+ nn.ModuleDict(
103
+ {
104
+ "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
105
+ "attn": nn.ModuleDict(
106
+ {
107
+ "qkv": nn.Linear(
108
+ config.enc_dim, 3 * config.enc_dim, dtype=dtype
109
+ ),
110
+ "proj": nn.Linear(
111
+ config.enc_dim, config.enc_dim, dtype=dtype
112
+ ),
113
+ }
114
+ ),
115
+ "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
116
+ "mlp": nn.ModuleDict(
117
+ {
118
+ "fc1": nn.Linear(
119
+ config.enc_dim, config.enc_ff_dim, dtype=dtype
120
+ ),
121
+ "fc2": nn.Linear(
122
+ config.enc_ff_dim, config.enc_dim, dtype=dtype
123
+ ),
124
+ }
125
+ ),
126
+ }
127
+ )
128
+ for _ in range(config.enc_n_layers)
129
+ ]
130
+ ),
131
+ "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
132
+ "proj_mlp": nn.ModuleDict(
133
+ {
134
+ "fc1": nn.Linear(
135
+ config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
136
+ ),
137
+ "fc2": nn.Linear(
138
+ config.proj_inner_dim, config.proj_out_dim, dtype=dtype
139
+ ),
140
+ }
141
+ ),
142
+ }
143
+ )
144
+ vision.pos_emb = nn.Parameter(
145
+ torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
146
+ )
147
+ return vision
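
The reshape/permute sequence in `create_patches` above is equivalent to a single einops rearrange (einops is already listed in requirements.txt); a sketch assuming the 378x378 crop size and 14-pixel patches used elsewhere in this upload:

```python
import torch
from einops import rearrange

x = torch.zeros(2, 3, 378, 378)   # (B, C, H, W): two crops at the assumed crop size
patches = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
print(patches.shape)              # torch.Size([2, 729, 588])
```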
vision_encoder.py ADDED
@@ -0,0 +1,325 @@
1
+ from typing import Union
2
+
3
+ import PIL.Image
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from einops import rearrange
8
+ import PIL
9
+ from torchvision.transforms.v2 import (
10
+ Compose,
11
+ Resize,
12
+ InterpolationMode,
13
+ ToImage,
14
+ ToDtype,
15
+ Normalize,
16
+ )
17
+ from transformers.utils import is_flash_attn_2_available
18
+
19
+ try:
20
+ if is_flash_attn_2_available():
21
+ from flash_attn.modules.mha import FlashSelfAttention
22
+ else:
23
+ FlashSelfAttention = None
24
+ except ImportError:
25
+ FlashSelfAttention = None
26
+
27
+
28
+ class Attention(nn.Module):
29
+
30
+ def __init__(self, dim, num_heads=16, use_flash_attn=False):
31
+ super().__init__()
32
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
33
+
34
+ self.num_heads = num_heads
35
+ self.head_dim = dim // num_heads
36
+
37
+ self.qkv = nn.Linear(dim, dim * 3)
38
+ self.proj = nn.Linear(dim, dim)
39
+
40
+ if use_flash_attn and FlashSelfAttention is not None:
41
+ self.flash_attn = FlashSelfAttention()
42
+ else:
43
+ self.flash_attn = None
44
+
45
+ torch.nn.init.kaiming_normal_(
46
+ self.qkv.weight, mode="fan_in", nonlinearity="relu"
47
+ )
48
+ torch.nn.init.kaiming_normal_(
49
+ self.proj.weight, mode="fan_in", nonlinearity="relu"
50
+ )
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ if self.flash_attn is not None:
54
+ qkv = self.qkv(x)
55
+ qkv = rearrange(
56
+ qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
57
+ )
58
+ attn_output = self.flash_attn(qkv)
59
+ output = rearrange(attn_output, "... h d -> ... (h d)")
60
+ output = self.proj(output)
61
+ return output
62
+ else:
63
+ B, N, C = x.shape
64
+ qkv = (
65
+ self.qkv(x)
66
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
67
+ .permute(2, 0, 3, 1, 4)
68
+ )
69
+ q, k, v = qkv.unbind(0)
70
+
71
+ x = F.scaled_dot_product_attention(q, k, v)
72
+
73
+ x = x.transpose(1, 2).reshape(B, N, C)
74
+ x = self.proj(x)
75
+ return x
76
+
77
+
78
+ class VitBlock(nn.Module):
79
+
80
+ def __init__(self, embed_dim, use_flash_attn=False):
81
+ super().__init__()
82
+ self.attn = Attention(embed_dim, use_flash_attn=use_flash_attn)
83
+ self.mlp = MLP(embed_dim, 4304)
84
+ self.norm1 = nn.LayerNorm(embed_dim)
85
+ self.norm2 = nn.LayerNorm(embed_dim)
86
+
87
+ def forward(self, x):
88
+ x = x + self.attn(self.norm1(x))
89
+ x = x + self.mlp(self.norm2(x))
90
+ return x
91
+
92
+
93
+ class VisionTransformer(nn.Module):
94
+
95
+ def __init__(self, use_flash_attn=False):
96
+ super().__init__()
97
+
98
+ embed_len = 729
99
+ embed_dim = 1152
100
+
101
+ self.patch_embed = LinearPatchEmbedding()
102
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
103
+ self.blocks = nn.Sequential(
104
+ *[VitBlock(embed_dim, use_flash_attn=use_flash_attn) for _ in range(27)]
105
+ )
106
+ self.norm = nn.LayerNorm(embed_dim)
107
+
108
+ def forward(self, x):
109
+ x = self.patch_embed(x)
110
+ x = x + self.pos_embed
111
+ for block in self.blocks:
112
+ x = block(x)
113
+ return self.norm(x)
114
+
115
+
116
+ class EncoderWrapper(nn.Module):
117
+
118
+ def __init__(self, use_flash_attn=False):
119
+ super().__init__()
120
+ self.model = nn.ModuleDict({"visual": VisionTransformer(use_flash_attn)})
121
+
122
+ def forward(self, x):
123
+ return self.model["visual"](x)
124
+
125
+
126
+ class LinearPatchEmbedding(nn.Module):
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+ self.linear = nn.Linear(588, 1152)
131
+
132
+ def forward(self, x):
133
+ b, c, hp1, wp2 = x.shape
134
+ p1, p2 = 14, 14
135
+ h, w = hp1 // p1, wp2 // p2
136
+ x = x.reshape(b, c, h, p1, w, p2)
137
+ x = x.permute(0, 2, 4, 1, 3, 5)
138
+ x = x.reshape(b, h * w, c * p1 * p2)
139
+
140
+ return self.linear(x)
141
+
142
+
143
+ class MLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ in_features: int,
147
+ hidden_features: int = None,
148
+ out_features: int = None,
149
+ ) -> None:
150
+ super().__init__()
151
+ out_features = out_features or in_features
152
+ hidden_features = hidden_features or in_features
153
+ self.fc1 = nn.Linear(in_features, hidden_features)
154
+ self.act = nn.GELU(approximate="tanh")
155
+ self.fc2 = nn.Linear(hidden_features, out_features)
156
+
157
+ torch.nn.init.kaiming_normal_(
158
+ self.fc1.weight, mode="fan_in", nonlinearity="relu"
159
+ )
160
+ torch.nn.init.kaiming_normal_(
161
+ self.fc2.weight, mode="fan_in", nonlinearity="relu"
162
+ )
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = self.fc1(x)
166
+ x = self.act(x)
167
+ x = self.fc2(x)
168
+ return x
169
+
170
+
171
+ class VisionProjection(nn.Module):
172
+ def __init__(self):
173
+ super().__init__()
174
+
175
+ image_embedding_dim = 1152
176
+ model_dim = 2048
177
+ hidden_dim = model_dim * 4
178
+
179
+ self.mlp = MLP(image_embedding_dim * 2, hidden_dim, model_dim)
180
+
181
+ @property
182
+ def device(self):
183
+ return self.mlp.fc1.weight.device
184
+
185
+ def forward(self, x):
186
+ return self.mlp(x)
187
+
188
+
189
+ def create_patches(image, patch_size=(378, 378)):
190
+ assert image.dim() == 3, "Image must be in CHW format"
191
+
192
+ _, height, width = image.shape # Channels, Height, Width
193
+ patch_height, patch_width = patch_size
194
+
195
+ if height == patch_height and width == patch_width:
196
+ return []
197
+
198
+ # Iterate over the image and create patches
199
+ patches = []
200
+ for i in range(0, height, patch_height):
201
+ row_patches = []
202
+ for j in range(0, width, patch_width):
203
+ patch = image[:, i : i + patch_height, j : j + patch_width]
204
+ row_patches.append(patch)
205
+ patches.append(torch.stack(row_patches))
206
+ return patches
207
+
208
+
209
+ class VisionEncoder(nn.Module):
210
+
211
+ def __init__(self, use_flash_attn=False):
212
+ super().__init__()
213
+
214
+ self.encoder = EncoderWrapper(use_flash_attn)
215
+ self.projection = VisionProjection()
216
+ self.supported_sizes = [(378, 378), (378, 756), (756, 378), (756, 756)]
217
+
218
+ @property
219
+ def device(self):
220
+ return self.projection.mlp.fc1.weight.device
221
+
222
+ @property
223
+ def dtype(self):
224
+ return self.projection.mlp.fc1.weight.dtype
225
+
226
+ def preprocess(self, image: PIL.Image.Image):
227
+ width, height = image.size
228
+ max_dim = max(width, height)
229
+ if max_dim < 512:
230
+ im_size = (378, 378)
231
+ else:
232
+ aspect_ratio = width / height
233
+ im_size = min(
234
+ self.supported_sizes,
235
+ key=lambda size: (
236
+ abs((size[1] / size[0]) - aspect_ratio),
237
+ abs(size[0] - width) + abs(size[1] - height),
238
+ ),
239
+ )
240
+
241
+ return Compose(
242
+ [
243
+ Resize(size=im_size, interpolation=InterpolationMode.BICUBIC),
244
+ ToImage(),
245
+ ToDtype(torch.float32, scale=True),
246
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
247
+ ]
248
+ )(image)
249
+
250
+ def forward(
251
+ self, images: Union[PIL.Image.Image, list[PIL.Image.Image], torch.Tensor]
252
+ ) -> torch.Tensor:
253
+ im_list = None
254
+ if isinstance(images, torch.Tensor):
255
+ # Input must have dimensions (B, C, H, W)
256
+ assert (
257
+ len(images.shape) == 4
258
+ ), "Tensor input must have dimensions (B, C, H, W)"
259
+ im_list = list(images)
260
+ elif isinstance(images, PIL.Image.Image):
261
+ im_list = [images]
262
+ elif isinstance(images, list):
263
+ im_list = images
264
+ else:
265
+ raise ValueError(
266
+ "Input must be a PIL image, list of PIL images, or a tensor"
267
+ )
268
+
269
+ # Preprocess unless the images are already tensors (indicating that
270
+ # they have already been preprocessed)
271
+ if not isinstance(im_list[0], torch.Tensor):
272
+ im_list = [self.preprocess(im.convert("RGB")) for im in im_list]
273
+
274
+ patches = [create_patches(im) for im in im_list]
275
+ flat_patches = [patch for image_patches in patches for patch in image_patches]
276
+
277
+ # Images may be variable size, and need to be resized to a common size after
278
+ # creating patches.
279
+ resized_images = [
280
+ F.interpolate(im.unsqueeze(0), size=(378, 378), mode="bilinear")
281
+ for im in im_list
282
+ ]
283
+
284
+ combined_images = torch.cat([*resized_images, *flat_patches], dim=0)
285
+ combined_images = combined_images.to(self.device, dtype=self.dtype)
286
+
287
+ combined_features = self.encoder(combined_images)
288
+
289
+ full_img_features = combined_features[: len(im_list)]
290
+ patch_features = (
291
+ combined_features[len(im_list) :].transpose(1, 2).view(-1, 1152, 27, 27)
292
+ )
293
+
294
+ # Reshape patch features back to their original structure
295
+ reshaped_patch_features = []
296
+ patch_idx = 0
297
+ for i, patch_set in enumerate(patches):
298
+ if len(patch_set) == 0:
299
+ reshaped_patch_features.append(
300
+ full_img_features[i].transpose(0, 1).view(1152, 27, 27)
301
+ )
302
+ else:
303
+ sample_features = []
304
+ for row_patches in patch_set:
305
+ row_len = len(row_patches)
306
+ row_features = patch_features[
307
+ patch_idx : patch_idx + row_len
308
+ ] # row_len, T, C
309
+ row_features = torch.cat(
310
+ list(row_features), dim=2
311
+ ) # T, C * row_len
312
+ patch_idx += row_len
313
+ sample_features.append(row_features)
314
+ sample_features = torch.cat(sample_features, dim=1)
315
+ sample_features = F.interpolate(
316
+ sample_features.unsqueeze(0), size=(27, 27), mode="bilinear"
317
+ ).squeeze(0)
318
+ reshaped_patch_features.append(sample_features)
319
+ reshaped_patch_features = (
320
+ torch.stack(reshaped_patch_features).view(-1, 1152, 729).transpose(1, 2)
321
+ )
322
+
323
+ final_features = torch.cat([full_img_features, reshaped_patch_features], dim=2)
324
+
325
+ return self.projection(final_features)
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
weights.py ADDED
@@ -0,0 +1,292 @@
1
+ import safetensors
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from typing import Callable, List
8
+
9
+ from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
10
+
11
+
12
+ @dataclass
13
+ class VisionBlock:
14
+ ln1: LayerNormWeights
15
+ attn: AttentionWeights
16
+ ln2: LayerNormWeights
17
+ mlp: MLPWeights
18
+
19
+
20
+ @dataclass
21
+ class VisionModel:
22
+ patch_emb: LinearWeights
23
+ pos_emb: torch.Tensor
24
+ blocks: List[VisionBlock]
25
+ post_ln: LayerNormWeights
26
+ proj_mlp: MLPWeights
27
+
28
+
29
+ @dataclass
30
+ class TextBlock:
31
+ ln: LayerNormWeights
32
+ attn: AttentionWeights
33
+ mlp: MLPWeights
34
+
35
+
36
+ @dataclass
37
+ class TextModel:
38
+ wte: torch.Tensor
39
+ blocks: List[TextBlock]
40
+ post_ln: LayerNormWeights
41
+ lm_head: LinearWeights
42
+
43
+
44
+ @dataclass
45
+ class RegionModel:
46
+ coord_features: torch.Tensor
47
+ coord_encoder: LinearWeights
48
+ coord_decoder: MLPWeights
49
+ size_features: torch.Tensor
50
+ size_encoder: LinearWeights
51
+ size_decoder: MLPWeights
52
+
53
+
54
+ @dataclass
55
+ class MoondreamModel:
56
+ vision: VisionModel
57
+ text: TextModel
58
+ region: RegionModel
59
+
60
+
61
+ @contextmanager
62
+ def safetensors_open(safetensors_file: str):
63
+ """
64
+ Simplify interfacing with safetensors files. Eliminates the need to ignore
65
+ type errors when using the `safe_open` function.
66
+ """
67
+ with safetensors.safe_open(
68
+ safetensors_file, framework="pt"
69
+ ) as st: # pyright: ignore
70
+
71
+ def get_tensor(name: str) -> torch.Tensor:
72
+ return st.get_tensor(name)
73
+
74
+ def get_keys() -> List[str]:
75
+ return st.keys()
76
+
77
+ get_tensor.keys = get_keys
78
+
79
+ yield get_tensor
80
+
81
+
82
+ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
83
+ """Internal function to load weights using a tensor getter function."""
84
+ model = model.to(dtype=torch.float16)
85
+
86
+ # Vision Model
87
+ model.vision["patch_emb"].weight.data.copy_(
88
+ get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight")
89
+ )
90
+ model.vision["patch_emb"].bias.data.copy_(
91
+ get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias")
92
+ )
93
+ model.vision.pos_emb.data.copy_(
94
+ get_tensor("vision_encoder.encoder.model.visual.pos_embed")
95
+ )
96
+
    for i in range(len(model.vision["blocks"])):
        prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"

        # Layer norms
        model.vision["blocks"][i]["ln1"].weight.data.copy_(
            get_tensor(f"{prefix}.norm1.weight")
        )
        model.vision["blocks"][i]["ln1"].bias.data.copy_(
            get_tensor(f"{prefix}.norm1.bias")
        )
        model.vision["blocks"][i]["ln2"].weight.data.copy_(
            get_tensor(f"{prefix}.norm2.weight")
        )
        model.vision["blocks"][i]["ln2"].bias.data.copy_(
            get_tensor(f"{prefix}.norm2.bias")
        )

        # Attention
        model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_(
            get_tensor(f"{prefix}.attn.qkv.weight")
        )
        model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_(
            get_tensor(f"{prefix}.attn.qkv.bias")
        )
        model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_(
            get_tensor(f"{prefix}.attn.proj.weight")
        )
        model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_(
            get_tensor(f"{prefix}.attn.proj.bias")
        )

        # MLP
        model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
            get_tensor(f"{prefix}.mlp.fc1.weight")
        )
        model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
            get_tensor(f"{prefix}.mlp.fc1.bias")
        )
        model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
            get_tensor(f"{prefix}.mlp.fc2.weight")
        )
        model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
            get_tensor(f"{prefix}.mlp.fc2.bias")
        )

    model.vision["post_ln"].weight.data.copy_(
        get_tensor("vision_encoder.encoder.model.visual.norm.weight")
    )
    model.vision["post_ln"].bias.data.copy_(
        get_tensor("vision_encoder.encoder.model.visual.norm.bias")
    )

    model.vision["proj_mlp"]["fc1"].weight.data.copy_(
        get_tensor("vision_encoder.projection.mlp.fc1.weight")
    )
    model.vision["proj_mlp"]["fc1"].bias.data.copy_(
        get_tensor("vision_encoder.projection.mlp.fc1.bias")
    )
    model.vision["proj_mlp"]["fc2"].weight.data.copy_(
        get_tensor("vision_encoder.projection.mlp.fc2.weight")
    )
    model.vision["proj_mlp"]["fc2"].bias.data.copy_(
        get_tensor("vision_encoder.projection.mlp.fc2.bias")
    )

    # Text Model
    model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight"))

    for i in range(len(model.text["blocks"])):
        prefix = f"text_model.transformer.h.{i}"

        # Layer norm
        model.text["blocks"][i]["ln"].weight.data.copy_(
            get_tensor(f"{prefix}.ln.weight")
        )
        model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias"))

        # Attention
        model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_(
            get_tensor(f"{prefix}.mixer.Wqkv.weight")
        )
        model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_(
            get_tensor(f"{prefix}.mixer.Wqkv.bias")
        )
        model.text["blocks"][i]["attn"]["proj"].weight.data.copy_(
            get_tensor(f"{prefix}.mixer.out_proj.weight")
        )
        model.text["blocks"][i]["attn"]["proj"].bias.data.copy_(
            get_tensor(f"{prefix}.mixer.out_proj.bias")
        )

        # MLP
        model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
            get_tensor(f"{prefix}.mlp.fc1.weight")
        )
        model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
            get_tensor(f"{prefix}.mlp.fc1.bias")
        )
        model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
            get_tensor(f"{prefix}.mlp.fc2.weight")
        )
        model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
            get_tensor(f"{prefix}.mlp.fc2.bias")
        )

    model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight"))
    model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias"))

    model.text["lm_head"].weight.data.copy_(
        get_tensor("text_model.lm_head.linear.weight")
    )
    model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias"))

    # Region Model
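    # The coordinate and size feature tables are stored transposed in the
    # checkpoint, hence the .T when copying them below.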
    model.region.coord_features.data.copy_(
        get_tensor("region_model.coordinate_features.weight").T
    )
    model.region["coord_encoder"].weight.data.copy_(
        get_tensor("region_model.coordinate_encoder.weight")
    )
    model.region["coord_encoder"].bias.data.copy_(
        get_tensor("region_model.coordinate_encoder.bias")
    )

    model.region["coord_decoder"]["fc1"].weight.data.copy_(
        get_tensor("region_model.coordinate_decoder.fc1.weight")
    )
    model.region["coord_decoder"]["fc1"].bias.data.copy_(
        get_tensor("region_model.coordinate_decoder.fc1.bias")
    )
    model.region["coord_decoder"]["fc2"].weight.data.copy_(
        get_tensor("region_model.coordinate_decoder.fc2.weight")
    )
    model.region["coord_decoder"]["fc2"].bias.data.copy_(
        get_tensor("region_model.coordinate_decoder.fc2.bias")
    )

    model.region.size_features.data.copy_(
        get_tensor("region_model.size_features.weight").T
    )
    model.region["size_encoder"].weight.data.copy_(
        get_tensor("region_model.size_encoder.weight")
    )
    model.region["size_encoder"].bias.data.copy_(
        get_tensor("region_model.size_encoder.bias")
    )

    model.region["size_decoder"]["fc1"].weight.data.copy_(
        get_tensor("region_model.size_decoder.fc1.weight")
    )
    model.region["size_decoder"]["fc1"].bias.data.copy_(
        get_tensor("region_model.size_decoder.fc1.bias")
    )
    model.region["size_decoder"]["fc2"].weight.data.copy_(
        get_tensor("region_model.size_decoder.fc2.weight")
    )
    model.region["size_decoder"]["fc2"].bias.data.copy_(
        get_tensor("region_model.size_decoder.fc2.bias")
    )


def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
    """Load weights from a safetensors file into a MoondreamModel instance."""
    with safetensors_open(weights_file) as get_tensor:
        # Wrap the get_tensor function to normalize keys: checkpoints saved from
        # torch.compile-wrapped modules carry a "._orig_mod" infix that must be
        # stripped before matching the names used in _load_weights.
        name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
        _load_weights(lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model)


def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
    """Load weights from a PyTorch file into a MoondreamModel instance."""
    device = str(torch.empty(0).device)
    tensors = torch.load(weights_file, map_location=device, weights_only=True)
    tensors = {
        k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
        for k, v in tensors.items()
    }
    _load_weights(lambda x: tensors[x], model)


def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
    """
    Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.

    Args:
        weights_file: Path to weights file (either .safetensors or .pt)
        model: MoondreamModel instance to load weights into
    """
    if weights_file.endswith(".safetensors"):
        load_weights_from_safetensors(weights_file, model)
    else:
        load_weights_from_pt(weights_file, model)

    # Make all parameters contiguous
    for param in model.parameters():
        param.data = param.data.contiguous()
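A minimal usage sketch, not part of the committed file: `model.safetensors` and `model` are placeholders, with `model` standing in for an `nn.Module` that exposes the `vision` / `text` / `region` submodules `_load_weights` expects.

```python
# Illustrative only: the checkpoint path and `model` are placeholders.
with safetensors_open("model.safetensors") as get_tensor:
    names = sorted(get_tensor.keys())  # enumerate tensor names in the checkpoint
    print(f"{len(names)} tensors, e.g. {names[0]}")

# `model` is assumed to be constructed elsewhere in the package.
load_weights_into_model("model.safetensors", model)
```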