Update README.md
Browse files
README.md
CHANGED
|
@@ -85,7 +85,7 @@ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
|
|
| 85 |
|
| 86 |
question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
|
| 87 |
|
| 88 |
-
inputs = processor(images=image, text=
|
| 89 |
|
| 90 |
predictions = model.generate(**inputs)
|
| 91 |
print(processor.decode(predictions[0], skip_special_tokens=True))
|
|
@@ -108,7 +108,7 @@ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
|
|
| 108 |
|
| 109 |
question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
|
| 110 |
|
| 111 |
-
inputs = processor(images=image, text=
|
| 112 |
|
| 113 |
predictions = model.generate(**inputs)
|
| 114 |
print(processor.decode(predictions[0], skip_special_tokens=True))
|
|
@@ -133,7 +133,7 @@ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-ai2d-base")
|
|
| 133 |
|
| 134 |
question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
|
| 135 |
|
| 136 |
-
inputs = processor(images=image, text=
|
| 137 |
|
| 138 |
predictions = model.generate(**inputs)
|
| 139 |
print(processor.decode(predictions[0], skip_special_tokens=True))
|
|
|
|
| 85 |
|
| 86 |
question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
|
| 87 |
|
| 88 |
+
inputs = processor(images=image, text=question, return_tensors="pt")
|
| 89 |
|
| 90 |
predictions = model.generate(**inputs)
|
| 91 |
print(processor.decode(predictions[0], skip_special_tokens=True))
|
|
|
|
| 108 |
|
| 109 |
question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
|
| 110 |
|
| 111 |
+
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
|
| 112 |
|
| 113 |
predictions = model.generate(**inputs)
|
| 114 |
print(processor.decode(predictions[0], skip_special_tokens=True))
|
|
|
|
| 133 |
|
| 134 |
question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
|
| 135 |
|
| 136 |
+
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda", torch.bfloat16)
|
| 137 |
|
| 138 |
predictions = model.generate(**inputs)
|
| 139 |
print(processor.decode(predictions[0], skip_special_tokens=True))
|