commit c06207f (1 parent: 58b2b58)
Showing 4 changed files with 262,404 additions and 0 deletions.
@@ -0,0 +1,19 @@
import fire
from deep_daze import Imagine
from pathlib import Path

def train(text, num_layers = 8):
    imagine = Imagine(
        text,
        num_layers = num_layers
    )

    if imagine.filename.exists():
        answer = input('Imagined image already exists, do you want to overwrite?').lower()
        if answer not in ('yes', 'y'):
            exit()

    imagine()

def main():
    fire.Fire(train)
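For orientation, the script above hands `train` to python-fire, so it can be driven from Python as well as from a shell. A minimal sketch, assuming the file is importable as `cli` (the diff does not show the file path, so the module name is a placeholder) and that `deep_daze` is installed:

# Hypothetical usage; `cli` stands in for whatever path this file was committed under.
from cli import train

# Render a prompt with a deeper SIREN than the default 8 layers.
train("a pineapple floating in space", num_layers=16)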
@@ -0,0 +1,239 @@
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

import hashlib
import os
import urllib.request
import warnings

from PIL import Image
from tqdm import tqdm
from pathlib import Path

import html
from functools import lru_cache
from collections import OrderedDict

import ftfy
import regex as re

MODEL_PATH = "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")

@lru_cache()
def bytes_to_unicode():
    # Map every byte value (0-255) to a printable unicode character, so BPE can
    # operate on arbitrary text without any "unknown" symbols.
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

def get_pairs(word):
    # Return the set of adjacent symbol pairs in a word (given as a tuple of symbols).
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = Path(bpe_path).read_text().split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        # Greedily apply the lowest-ranked merge until no known merges remain.
        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
def _download(url, root = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
normalize_image = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

def load(device = ("cuda" if torch.cuda.is_available() else "cpu")):
    model_path = _download(MODEL_PATH)
    model = torch.jit.load(model_path, map_location = device).eval()
    n_px = model.input_resolution.item()

    transform = Compose([
        Resize(n_px, interpolation=Image.BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        normalize_image,
    ])

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        graphs = [module.graph] if hasattr(module, "graph") else []
        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if device == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            graphs = [module.graph] if hasattr(module, "graph") else []
            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]: # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5: # 5 is the ScalarType code for float16
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, transform

_tokenizer = SimpleTokenizer()
def tokenize(texts, context_length: int = 77):
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
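As a rough sketch of how the pieces above fit together (a hypothetical illustration, not part of the commit): `load` returns the jit-compiled CLIP model together with its preprocessing transform, and `tokenize` prepares prompts for `encode_text`. The image path and prompts below are placeholders, and the definitions above are assumed to be in scope.

# Hypothetical usage of the module above; "photo.jpg" is a placeholder path.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = load(device)

image = transform(Image.open("photo.jpg")).unsqueeze(0).to(device)
text = tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # Cosine similarity of the image against each prompt
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = image_features @ text_features.T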