Commit 6c1a2d67 authored by Eren Doğan's avatar Eren Doğan Committed by GitHub

Merge pull request #9 from NovelAI/os.changes

parents 91470a5a 9fc1cc21
import numpy as np
import torch
import mmap
import pickle
import concurrent
from torch.utils import data
from simplejpeg import decode_jpeg
import simplejpeg
import pickle
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from concurrent.futures import as_completed
import requests
......@@ -54,6 +50,9 @@ class ShardedDataset(data.Dataset):
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
from simplejpeg import decode_jpeg
import simplejpeg
from PIL import Image
self.skip = skip
self.threads = threads
......
......@@ -87,11 +87,13 @@ def load_from_path(config_folder=None, strict=False):
model = _load_dict_model(model_class, model_config, model_path, strict=strict)
return model
def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
def _load_dict_model(model_class, config, path=None, state_dict=None,
strict=False, device="cuda"):
# I am kinda sad that we will not have a load function in lm object itself.
# might be better to add load functions -- actually nope.
if path:
state_dict = utils.SplitCheckpoint(path, device="cuda")
state_dict = utils.SplitCheckpoint(path, device=device)
state_dict.device = device
model= utils.no_init(lambda: model_class(config))
model.load_state_dict(state_dict, strict=strict)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment