.. _sec_utils:

Utility Functions and Classes
=============================

This section contains the implementations of the utility functions and
classes used in this book.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import collections
    import hashlib
    import inspect
    import math
    import os
    import random
    import tarfile
    import zipfile

    import mlx
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    import mlx.utils
    import numpy as np
    import requests
    from IPython import display
    from d2l import mlx as d2l

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class HyperParameters:  #@save
        """The base class of hyperparameters."""
        def save_hyperparameters(self, ignore=[]):
            """Defined in :numref:`sec_oo-design`"""
            raise NotImplementedError

        def save_hyperparameters(self, ignore=[]):
            """Save function arguments into class attributes.

            Defined in :numref:`sec_utils`"""
            frame = inspect.currentframe().f_back
            _, _, _, local_vars = inspect.getargvalues(frame)
            self.hparams = {k: v for k, v in local_vars.items()
                            if k not in set(ignore + ['self'])
                            and not k.startswith('_')}
            for k, v in self.hparams.items():
                setattr(self, k, v)

    class DataModule(d2l.HyperParameters):  #@save
        """The base class of data.

        Defined in :numref:`subsec_oo-design-models`"""
        def __init__(self, root='../data'):
            self.save_hyperparameters()

        def get_dataloader(self, train):
            raise NotImplementedError

        def train_dataloader(self):
            return self.get_dataloader(train=True)

        def val_dataloader(self):
            return self.get_dataloader(train=False)

        def get_tensorloader(self, tensors, train, indices=slice(0, None)):
            """Defined in :numref:`sec_synthetic-regression-data`"""
            tensors = tuple(a[indices] for a in tensors)
            dataset = d2l.Dataset(*tensors)
            return d2l.DataLoader(dataset, self.batch_size, shuffle=train)

    class ProgressBoard(d2l.HyperParameters):  #@save
        """The board that plots data points in animation.

        Defined in :numref:`sec_oo-design`"""
        def __init__(self, xlabel=None, ylabel=None, xlim=None,
                     ylim=None, xscale='linear', yscale='linear',
                     ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                     fig=None, axes=None, figsize=(3.5, 2.5), display=True):
            self.save_hyperparameters()

        def draw(self, x, y, label, every_n=1):
            raise NotImplementedError

        def draw(self, x, y, label, every_n=1):
            """Defined in :numref:`sec_utils`"""
            Point = collections.namedtuple('Point', ['x', 'y'])
            if not hasattr(self, 'raw_points'):
                self.raw_points = collections.OrderedDict()
                self.data = collections.OrderedDict()
            if label not in self.raw_points:
                self.raw_points[label] = []
                self.data[label] = []
            points = self.raw_points[label]
            line = self.data[label]
            points.append(Point(x, y))
            if len(points) != every_n:
                return
            mean = lambda x: sum(x) / len(x)
            line.append(Point(mean([p.x for p in points]),
                              mean([p.y for p in points])))
            points.clear()
            if not self.display:
                return
            d2l.use_svg_display()
            if self.fig is None:
                self.fig = d2l.plt.figure(figsize=self.figsize)
            plt_lines, labels = [], []
            for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
                plt_lines.append(d2l.plt.plot([p.x for p in v], [p.y for p in v],
                                              linestyle=ls, color=color)[0])
                labels.append(k)
            axes = self.axes if self.axes else d2l.plt.gca()
            if self.xlim: axes.set_xlim(self.xlim)
            if self.ylim: axes.set_ylim(self.ylim)
            if not self.xlabel: self.xlabel = self.x
            axes.set_xlabel(self.xlabel)
            axes.set_ylabel(self.ylabel)
            axes.set_xscale(self.xscale)
            axes.set_yscale(self.yscale)
            axes.legend(plt_lines, labels)
            display.display(self.fig)
            display.clear_output(wait=True)
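
As a quick illustration (a minimal sketch mirroring the example in
:numref:`sec_oo-design`; the class ``B`` is a hypothetical demo class, not
part of the d2l module), ``save_hyperparameters`` turns every constructor
argument except the ignored ones into an attribute:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class B(d2l.HyperParameters):
        def __init__(self, a, b, c):
            # Saves `a` and `b` as `self.a` and `self.b`; `c` is ignored
            self.save_hyperparameters(ignore=['c'])
            print('self.a =', self.a, 'self.b =', self.b)
            print('There is no self.c =', not hasattr(self, 'c'))

    b = B(a=1, b=2, c=3)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
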
    class Module(nn.Module, d2l.HyperParameters):  #@save
        """The base class of models.

        Defined in :numref:`sec_oo-design`"""
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()

        def loss(self, y_hat, y):
            raise NotImplementedError

        def __call__(self, X):
            assert hasattr(self, 'net'), 'Neural network is not defined'
            return self.net(X)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not inited'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            self.board.draw(x, np.array(value),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

        def training_step(self, batch):
            def loss_fn(X, y):
                return self.loss(self(*X), y)
            l, grad = nn.value_and_grad(self, loss_fn)(batch[:-1], batch[-1])
            self.plot('loss', l.item(), train=True)
            return l, grad

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l.item(), train=False)

        def configure_optimizers(self):
            raise NotImplementedError

        def configure_optimizers(self):
            """Defined in :numref:`sec_classification`"""
            return optim.SGD(learning_rate=self.lr)

    class RNNScratch(d2l.Module):  #@save
        """The RNN model implemented from scratch.

        Defined in :numref:`sec_rnn-scratch`"""
        def __init__(self, num_inputs, num_hiddens, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.W_xh = mx.random.normal(shape=(num_inputs, num_hiddens)) * sigma
            self.W_hh = mx.random.normal(shape=(num_hiddens, num_hiddens)) * sigma
            self.b_h = mx.zeros(num_hiddens)

        def forward(self, inputs, state=None):
            """Defined in :numref:`sec_rnn-scratch`"""
            if state is None:
                # Initial state with shape: (batch_size, num_hiddens)
                state = mx.zeros((inputs.shape[1], self.num_hiddens))
            else:
                state, = state
            outputs = []
            # Shape of inputs: (num_steps, batch_size, num_inputs)
            for X in inputs:
                state = mx.tanh(mx.matmul(X, self.W_xh) +
                                mx.matmul(state, self.W_hh) + self.b_h)
                outputs.append(state)
            return outputs, state

        def __call__(self, inputs, state=None):
            """Defined in :numref:`sec_rnn-scratch`"""
            if state is None:
                # Initial state with shape: (batch_size, num_hiddens)
                state = mx.zeros((inputs.shape[1], self.num_hiddens))
            else:
                state, = state
            outputs = []
            # Shape of inputs: (num_steps, batch_size, num_inputs)
            for X in inputs:
                state = mx.tanh(mx.matmul(X, self.W_xh) +
                                mx.matmul(state, self.W_hh) + self.b_h)
                outputs.append(state)
            return outputs, state

    def check_len(a, n):  #@save
        """Check the length of a list.

        Defined in :numref:`sec_rnn-scratch`"""
        assert len(a) == n, f'list\'s length {len(a)} != expected length {n}'

    def check_shape(a, shape):  #@save
        """Check the shape of a tensor.

        Defined in :numref:`sec_rnn-scratch`"""
        assert a.shape == shape, \
            f'tensor\'s shape {a.shape} != expected shape {shape}'
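
As a sanity check (a short sketch in the spirit of
:numref:`sec_rnn-scratch`), we can run the scratch RNN on a dummy batch and
verify output shapes with ``check_len`` and ``check_shape``:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
    rnn = d2l.RNNScratch(num_inputs, num_hiddens)
    X = mx.ones((num_steps, batch_size, num_inputs))
    outputs, state = rnn(X)
    d2l.check_len(outputs, num_steps)
    d2l.check_shape(outputs[0], (batch_size, num_hiddens))
    d2l.check_shape(state, (batch_size, num_hiddens))

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
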
    class RNN(d2l.Module):  #@save
        """The RNN model implemented with high-level APIs.

        Defined in :numref:`sec_rnn-concise`"""
        def __init__(self, num_inputs, num_hiddens):
            super().__init__()
            self.save_hyperparameters()
            self.rnn = d2l.RNNScratch(num_inputs, num_hiddens)

        def __call__(self, inputs, H=None):
            return self.rnn(inputs, H)

    class GRUScratch(d2l.Module):  #@save
        """Defined in :numref:`sec_gru`"""
        def __init__(self, num_inputs, num_hiddens, sigma=0.01, dropout=0):
            super().__init__()
            self.save_hyperparameters()
            init_weight = lambda *shape: mx.random.normal(*shape) * sigma
            triple = lambda: (init_weight((num_inputs, num_hiddens)),
                              init_weight((num_hiddens, num_hiddens)),
                              mx.zeros(shape=(num_hiddens,)))
            self.W_xz, self.W_hz, self.b_z = triple()  # Update gate
            self.W_xr, self.W_hr, self.b_r = triple()  # Reset gate
            self.W_xh, self.W_hh, self.b_h = triple()  # Candidate hidden state

        def __call__(self, inputs, H=None):
            """Defined in :numref:`sec_gru`"""
            if H is None:
                # Initial state with shape: (batch_size, num_hiddens)
                H = mx.zeros((inputs.shape[1], self.num_hiddens))
            outputs = []
            for X in inputs:
                Z = mx.sigmoid(mx.matmul(X, self.W_xz) +
                               mx.matmul(H, self.W_hz) + self.b_z)
                R = mx.sigmoid(mx.matmul(X, self.W_xr) +
                               mx.matmul(H, self.W_hr) + self.b_r)
                H_tilde = mx.tanh(mx.matmul(X, self.W_xh) +
                                  mx.matmul(R * H, self.W_hh) + self.b_h)
                H = Z * H + (1 - Z) * H_tilde
                H = d2l.dropout_layer(H, self.dropout)
                outputs.append(H)
            return outputs, H

    class StackedGRUScratch(d2l.Module):  #@save
        """Defined in :numref:`sec_deep_rnn`"""
        def __init__(self, num_inputs, num_hiddens, num_layers, sigma=0.01,
                     dropout=0):
            super().__init__()
            self.save_hyperparameters()
            # Forward `dropout` to each layer so that it is not silently ignored
            self.grus = nn.Sequential(*[d2l.GRUScratch(
                num_inputs if i == 0 else num_hiddens, num_hiddens, sigma,
                dropout) for i in range(num_layers)])

        def __call__(self, inputs, Hs=None):
            outputs = inputs
            if Hs is None:
                Hs = [None] * self.num_layers
            for i in range(self.num_layers):
                outputs, Hs[i] = self.grus.layers[i](outputs, Hs[i])
            outputs = mx.stack(outputs, axis=0)
            return outputs, Hs

    class GRU(d2l.RNN):  #@save
        """The multilayer GRU model.

        Defined in :numref:`sec_deep_rnn`"""
        def __init__(self, num_inputs, num_hiddens, num_layers, dropout=0):
            d2l.Module.__init__(self)
            self.save_hyperparameters()
            self.rnn = d2l.StackedGRUScratch(num_inputs, num_hiddens,
                                             num_layers, sigma=0.01,
                                             dropout=dropout)

    def dropout_layer(X, dropout):  #@save
        """Defined in :numref:`sec_dropout`"""
        assert 0 <= dropout <= 1
        if dropout == 1:
            return mx.zeros_like(X)
        mask = (mx.random.uniform(shape=X.shape) > dropout).astype(mx.float32)
        return mask * X / (1.0 - dropout)

    def add_to_class(Class):  #@save
        """Register functions as methods in created class.

        Defined in :numref:`sec_oo-design`"""
        def wrapper(obj):
            setattr(Class, obj.__name__, obj)
        return wrapper

    def num_gpus():  #@save
        """Get the number of available GPUs.

        Defined in :numref:`sec_use_gpu`"""
        return 1

    def gpu(i=0):  #@save
        """Get a GPU device.

        Defined in :numref:`sec_use_gpu`"""
        mx.set_default_device(mx.gpu)
        return mx.default_device()

    def try_gpu(i=0):  #@save
        """Return gpu(i) if exists, otherwise return cpu().

        Defined in :numref:`sec_use_gpu`"""
        return gpu(0)
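
``add_to_class`` is the mechanism the book uses to define methods after a
class has been created. A minimal sketch, mirroring the example in
:numref:`sec_oo-design` (``A`` and ``do`` are hypothetical demo names):

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class A:
        def __init__(self):
            self.b = 1

    a = A()

    @d2l.add_to_class(A)
    def do(self):
        print('Class attribute "b" is', self.b)

    # The method registered after creation is available on existing instances
    a.do()

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
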
    class Trainer(d2l.HyperParameters):  #@save
        """The base class for training models with data.

        Defined in :numref:`subsec_oo-design-models`"""
        def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
            self.save_hyperparameters()
            assert num_gpus == 0, 'No GPU support yet'

        def prepare_data(self, data):
            self.train_dataloader = data.train_dataloader()
            self.val_dataloader = data.val_dataloader()
            self.num_train_batches = len(self.train_dataloader)
            self.num_val_batches = (len(self.val_dataloader)
                                    if self.val_dataloader is not None else 0)

        def prepare_model(self, model):
            model.trainer = self
            model.board.xlim = [0, self.max_epochs]
            self.model = model

        def fit(self, model, data):
            self.prepare_data(data)
            self.prepare_model(model)
            self.optim = model.configure_optimizers()
            self.epoch = 0
            self.train_batch_idx = 0
            self.val_batch_idx = 0
            for self.epoch in range(self.max_epochs):
                self.fit_epoch()

        def fit_epoch(self):
            raise NotImplementedError

        def prepare_batch(self, batch):
            """Defined in :numref:`sec_linear_scratch`"""
            return batch

        def fit_epoch(self):
            """Defined in :numref:`sec_linear_scratch`"""
            self.model.train(True)
            for batch in self.train_dataloader:
                loss, grads = self.model.training_step(self.prepare_batch(batch))
                if self.gradient_clip_val > 0:
                    grads = self.clip_gradients(self.gradient_clip_val, grads)
                self.optim.update(model=self.model, gradients=grads)
                mx.eval(self.model.parameters())
                self.train_batch_idx += 1
            if self.val_dataloader is None:
                return
            self.model.eval()
            for batch in self.val_dataloader:
                self.model.validation_step(self.prepare_batch(batch))
                self.val_batch_idx += 1

        def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
            """Defined in :numref:`sec_use_gpu`"""
            self.save_hyperparameters()
            self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]

        def prepare_batch(self, batch):
            """Defined in :numref:`sec_use_gpu`"""
            if self.gpus:
                # MLX uses unified memory, so setting the default device suffices
                gpu()
            batch = [a for a in batch]
            return batch

        def prepare_model(self, model):
            """Defined in :numref:`sec_use_gpu`"""
            model.trainer = self
            model.board.xlim = [0, self.max_epochs]
            if self.gpus:
                gpu()
            self.model = model

        def clip_gradients(self, grad_clip_val, grads):
            """Defined in :numref:`sec_rnn-scratch`"""
            grad_leaves = mlx.utils.tree_flatten(grads)
            norm = mx.sqrt(sum((x[1] ** 2).sum() for x in grad_leaves))
            clip = lambda grad: mx.where(norm < grad_clip_val, grad,
                                         grad * (grad_clip_val / norm))
            return mlx.utils.tree_map(clip, grads)
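
To see how ``Trainer``, ``Module``, and ``DataModule`` fit together, here is
a minimal end-to-end sketch. ``ToyData`` and ``ToyModel`` are hypothetical
demo classes written only for this illustration, and the run assumes a
notebook environment for the progress plot:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class ToyData(d2l.DataModule):  # hypothetical demo data module
        def __init__(self, n=64, batch_size=8):
            super().__init__()
            self.save_hyperparameters()
            self.X = mx.random.normal(shape=(n, 2))
            self.y = mx.matmul(self.X, mx.array([2.0, -3.4])) + 4.2

        def get_dataloader(self, train):
            return self.get_tensorloader((self.X, self.y), train)

    class ToyModel(d2l.Module):  # hypothetical demo model
        def __init__(self, lr=0.1):
            super().__init__()
            self.save_hyperparameters()
            self.net = nn.Linear(2, 1)

        def loss(self, y_hat, y):
            # Mean squared error on a synthetic linear-regression task
            return ((y_hat.reshape(-1) - y) ** 2).mean()

        def configure_optimizers(self):
            return optim.SGD(learning_rate=self.lr)

    trainer = d2l.Trainer(max_epochs=2)
    trainer.fit(ToyModel(), ToyData())

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
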
    def download_new(url, folder='../data', sha1_hash=None):  #@save
        """Download a file to folder and return the local filepath.

        Defined in :numref:`sec_utils`"""
        if not url.startswith('http'):
            # For backward compatibility
            url, sha1_hash = DATA_HUB[url]
        os.makedirs(folder, exist_ok=True)
        fname = os.path.join(folder, url.split('/')[-1])
        # Check if the cache is hit
        if os.path.exists(fname) and sha1_hash:
            sha1 = hashlib.sha1()
            with open(fname, 'rb') as f:
                while True:
                    data = f.read(1048576)
                    if not data:
                        break
                    sha1.update(data)
            if sha1.hexdigest() == sha1_hash:
                return fname
        # Download
        print(f'Downloading {fname} from {url}...')
        r = requests.get(url, stream=True, verify=True)
        with open(fname, 'wb') as f:
            f.write(r.content)
        return fname

    def extract(filename, folder=None):  #@save
        """Extract a zip/tar file into folder.

        Defined in :numref:`sec_utils`"""
        base_dir = os.path.dirname(filename)
        _, ext = os.path.splitext(filename)
        assert ext in ('.zip', '.tar', '.gz'), 'Only zip/tar files are supported.'
        if ext == '.zip':
            fp = zipfile.ZipFile(filename, 'r')
        else:
            fp = tarfile.open(filename, 'r')
        if folder is None:
            folder = base_dir
        fp.extractall(folder)

    class Classifier(d2l.Module):  #@save
        """The base class of classification models.

        Defined in :numref:`sec_classification`"""
        def validation_step(self, batch):
            Y_hat = self(*batch[:-1])
            self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
            self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

        def accuracy(self, Y_hat, Y, averaged=True):
            """Compute the number of correct predictions.

            Defined in :numref:`sec_classification`"""
            Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
            preds = Y_hat.argmax(axis=1).astype(Y.dtype)
            compare = (preds == Y.reshape(-1)).astype(mx.float32)
            return compare.mean() if averaged else compare

        def loss(self, Y_hat, Y, averaged=True):
            """Defined in :numref:`sec_softmax_concise`"""
            Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
            Y = Y.reshape((-1,))
            return nn.losses.cross_entropy(
                Y_hat, Y, reduction='mean' if averaged else 'none')

        def layer_summary(self, X_shape):
            """Defined in :numref:`sec_lenet`"""
            X = mx.random.normal(shape=X_shape)
            for layer in self.net.layers:
                X = layer(X)
                print(layer.__class__.__name__, 'output shape:\t', X.shape)
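
A quick sketch of ``Classifier.accuracy`` on hand-written logits (the values
below are made up for illustration): the first row predicts class 1
correctly, the second predicts class 0 while the label is 2.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    clf = d2l.Classifier()
    Y_hat = mx.array([[0.1, 0.8, 0.1], [0.6, 0.2, 0.2]])
    Y = mx.array([1, 2])
    print(clf.accuracy(Y_hat, Y))  # 0.5: one of two predictions is correct

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
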
    class MTFraEng(d2l.DataModule):  #@save
        """The English-French dataset.

        Defined in :numref:`sec_machine_translation`"""
        def _download_new(self):
            d2l.extract(d2l.download_new(
                d2l.DATA_URL + 'fra-eng.zip', self.root,
                '94646ad1522d915e7b0f9296181140edcf86a4f5'))
            with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
                return f.read()

        def _preprocess(self, text):
            """Defined in :numref:`sec_machine_translation`"""
            # Replace non-breaking space with space
            text = text.replace('\u202f', ' ').replace('\xa0', ' ')
            # Insert space between words and punctuation marks
            no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
            out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
                   for i, char in enumerate(text.lower())]
            return ''.join(out)

        def _tokenize(self, text, max_examples=None):
            """Defined in :numref:`sec_machine_translation`"""
            src, tgt = [], []
            for i, line in enumerate(text.split('\n')):
                if max_examples and i > max_examples:
                    break
                parts = line.split('\t')
                if len(parts) == 2:
                    # Skip empty tokens
                    src.append([t for t in f'{parts[0]} <eos>'.split(' ') if t])
                    tgt.append([t for t in f'{parts[1]} <eos>'.split(' ') if t])
            return src, tgt

        def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
            """Defined in :numref:`sec_machine_translation`"""
            super(MTFraEng, self).__init__()
            self.save_hyperparameters()
            self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
                self._download_new())

        def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
            """Defined in :numref:`subsec_mt_data_loading`"""
            def _build_array(sentences, vocab, is_tgt=False):
                pad_or_trim = lambda seq, t: (
                    seq[:t] if len(seq) > t else seq + ['<pad>'] * (t - len(seq)))
                sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
                if is_tgt:
                    sentences = [['<bos>'] + s for s in sentences]
                if vocab is None:
                    vocab = d2l.Vocab(sentences, min_freq=2)
                array = mx.array([vocab[s] for s in sentences])
                valid_len = (array != vocab['<pad>']).astype(mx.int32).sum(1)
                return array, vocab, valid_len
            src, tgt = self._tokenize(self._preprocess(raw_text),
                                      self.num_train + self.num_val)
            src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
            tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
            return ((src_array, tgt_array[:, :-1], src_valid_len,
                     tgt_array[:, 1:]),
                    src_vocab, tgt_vocab)

        def get_dataloader(self, train):
            """Defined in :numref:`subsec_mt_data_loading`"""
            idx = slice(0, self.num_train) if train else slice(self.num_train, None)
            return self.get_tensorloader(self.arrays, train, idx)

        def build(self, src_sentences, tgt_sentences):
            """Defined in :numref:`subsec_mt_data_loading`"""
            raw_text = '\n'.join([src + '\t' + tgt for src, tgt in zip(
                src_sentences, tgt_sentences)])
            arrays, _, _ = self._build_arrays(
                raw_text, self.src_vocab, self.tgt_vocab)
            return arrays
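
A short sketch mirroring :numref:`subsec_mt_data_loading` (note that
constructing ``MTFraEng`` downloads the dataset): each minibatch consists of
the source sequence, the decoder input, the source valid lengths, and the
label, all of length ``num_steps``.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    data = d2l.MTFraEng(batch_size=3)
    src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))
    d2l.check_shape(src, (3, 9))    # (batch_size, num_steps)
    d2l.check_shape(tgt, (3, 9))    # decoder input, starts with <bos>
    d2l.check_shape(label, (3, 9))  # decoder target, shifted by one

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
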
    class Encoder(d2l.Module):  #@save
        """The base encoder interface for the encoder--decoder architecture."""
        def __init__(self):
            super().__init__()

        def __call__(self, X, *args):
            raise NotImplementedError

    class Decoder(d2l.Module):  #@save
        """The base decoder interface for the encoder--decoder architecture.

        Defined in :numref:`sec_encoder-decoder`"""
        def __init__(self):
            super().__init__()

        def init_state(self, enc_all_outputs, *args):
            raise NotImplementedError

        def __call__(self, X, state):
            raise NotImplementedError

    class EncoderDecoder(d2l.Classifier):  #@save
        """The base class for the encoder--decoder architecture.

        Defined in :numref:`sec_encoder-decoder`"""
        def __init__(self, encoder, decoder):
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder

        def __call__(self, enc_X, dec_X, *args):
            enc_all_outputs = self.encoder(enc_X, *args)
            dec_state = self.decoder.init_state(enc_all_outputs, *args)
            return self.decoder(dec_X, dec_state)[0]

        def predict_step(self, batch, device, num_steps,
                         save_attention_weights=False):
            """Defined in :numref:`sec_seq2seq_decoder`"""
            src, tgt, src_valid_len, _ = batch
            enc_all_outputs = self.encoder(src, src_valid_len)
            dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len)
            outputs, attention_weights = [tgt[:, 0].reshape(-1, 1)], []
            for _ in range(num_steps):
                Y, dec_state = self.decoder(outputs[-1], dec_state)
                outputs.append(Y.argmax(2))
                # Save attention weights (to be covered later)
                if save_attention_weights:
                    attention_weights.append(self.decoder.attention_weights)
            return mx.concatenate(outputs[1:], 1), attention_weights

    def init_seq2seq(array):  #@save
        """Defined in :numref:`sec_seq2seq`"""
        if array.ndim > 1:
            weight_fn = nn.init.glorot_uniform()
            array = weight_fn(array)
        return array

    class Seq2SeqEncoder(d2l.Encoder):  #@save
        """The RNN encoder for sequence-to-sequence learning.

        Defined in :numref:`sec_seq2seq`"""
        def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                     dropout=0):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_size)
            self.rnn = d2l.GRU(embed_size, num_hiddens, num_layers, dropout)
            for module in self.modules():
                if isinstance(module, nn.Linear) or isinstance(module, d2l.GRU):
                    module.update(mlx.utils.tree_map(
                        lambda x: init_seq2seq(x), module.parameters()))

        def __call__(self, X, *args):
            # X shape: (batch_size, num_steps)
            embs = self.embedding(X.T.astype(mx.int64))
            # embs shape: (num_steps, batch_size, embed_size)
            outputs, state = self.rnn(embs)
            state = mx.array(state)
            # outputs shape: (num_steps, batch_size, num_hiddens)
            # state shape: (num_layers, batch_size, num_hiddens)
            return outputs, state

    class Seq2SeqDecoder(d2l.Decoder):  #@save
        """The RNN decoder for sequence to sequence learning.

        Defined in :numref:`sec_seq2seq`"""
        def __init__(self, num_inputs, vocab_size, embed_size, num_hiddens,
                     num_layers, dropout=0):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_size)
            self.rnn = d2l.GRU(embed_size + num_hiddens, num_hiddens,
                               num_layers, dropout)
            self.dense = nn.Linear(num_inputs, vocab_size)
            for module in self.modules():
                if isinstance(module, nn.Linear) or isinstance(module, d2l.GRU):
                    module.update(mlx.utils.tree_map(
                        lambda x: init_seq2seq(x), module.parameters()))

        def init_state(self, enc_all_outputs, *args):
            return enc_all_outputs

        def __call__(self, X, state):
            # X shape: (batch_size, num_steps)
            # embs shape: (num_steps, batch_size, embed_size)
            embs = self.embedding(X.T.astype(mx.int32))
            enc_output, hidden_state = state
            # context shape: (batch_size, num_hiddens)
            context = enc_output[-1]
            # Broadcast context to (num_steps, batch_size, num_hiddens)
            context = mx.tile(context, (embs.shape[0], 1, 1))
            # Concat at the feature dimension
            embs_and_context = mx.concatenate((embs, context), -1)
            outputs, hidden_state = self.rnn(embs_and_context, hidden_state)
            outputs = self.dense(outputs).swapaxes(0, 1)
            # outputs shape: (batch_size, num_steps, vocab_size)
            # hidden_state shape: (num_layers, batch_size, num_hiddens)
            return outputs, [enc_output, hidden_state]
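
A shape check for the encoder, in the spirit of :numref:`sec_seq2seq` (the
sizes below are arbitrary demo values):

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
    batch_size, num_steps = 4, 9
    encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
    X = mx.zeros((batch_size, num_steps), dtype=mx.int64)
    outputs, state = encoder(X)
    d2l.check_shape(outputs, (num_steps, batch_size, num_hiddens))
    d2l.check_shape(state, (num_layers, batch_size, num_hiddens))

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
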
    class Seq2Seq(d2l.EncoderDecoder):  #@save
        """The RNN encoder--decoder for sequence to sequence learning.

        Defined in :numref:`sec_seq2seq_decoder`"""
        def __init__(self, encoder, decoder, tgt_pad, lr):
            super().__init__(encoder, decoder)
            self.save_hyperparameters()

        def validation_step(self, batch):
            Y_hat = self(*batch[:-1])
            self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)

        def configure_optimizers(self):
            # The Adam optimizer is used here
            return optim.Adam(learning_rate=self.lr)

        def loss(self, Y_hat, Y):
            """Defined in :numref:`sec_seq2seq_decoder`"""
            l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False)
            mask = (Y.reshape(-1) != self.tgt_pad).astype(mx.float32)
            return (l * mask).sum() / mask.sum()

    def bleu(pred_seq, label_seq, k):  #@save
        """Compute the BLEU.

        Defined in :numref:`sec_seq2seq_decoder`"""
        pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
        len_pred, len_label = len(pred_tokens), len(label_tokens)
        score = math.exp(min(0, 1 - len_label / len_pred))
        for n in range(1, min(k, len_pred) + 1):
            num_matches, label_subs = 0, collections.defaultdict(int)
            for i in range(len_label - n + 1):
                label_subs[' '.join(label_tokens[i: i + n])] += 1
            for i in range(len_pred - n + 1):
                if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                    num_matches += 1
                    label_subs[' '.join(pred_tokens[i: i + n])] -= 1
            score *= math.pow(num_matches / (len_pred - n + 1),
                              math.pow(0.5, n))
        return score

    class Dataset:  #@save
        def __init__(self, *tensors):
            """Defined in :numref:`sec_fashion_mnist`"""
            assert all(
                tensors[0].shape[0] == tensor.shape[0] for tensor in tensors
            ), "Size mismatch between tensors"
            self.tensors = tensors

        def __getitem__(self, index):
            return tuple(mx.array(tensor[index]) for tensor in self.tensors)

        def __len__(self):
            return self.tensors[0].shape[0]

    class DataLoader:  #@save
        def __init__(self, dataset, batch_size=1, shuffle=False,
                     drop_last=False):
            """Defined in :numref:`sec_fashion_mnist`"""
            self.dataset = dataset
            self.batch_size = batch_size
            self.shuffle = shuffle
            self.drop_last = drop_last
            self.indices = list(range(len(dataset)))
            self.current_index = 0
            if self.shuffle:
                random.shuffle(self.indices)

        def __iter__(self):
            self.current_index = 0
            if self.shuffle:
                random.shuffle(self.indices)
            return self

        def __next__(self):
            if self.current_index >= len(self.indices):
                raise StopIteration
            end_index = self.current_index + self.batch_size
            if end_index > len(self.indices):
                if self.drop_last:
                    raise StopIteration
                else:
                    end_index = len(self.indices)
            batch_indices = self.indices[self.current_index:end_index]
            batch = [self.dataset[i] for i in batch_indices]
            self.current_index = end_index
            return self.collate_fn(batch)

        def __len__(self):
            if self.drop_last:
                return len(self.dataset) // self.batch_size
            else:
                return math.ceil(len(self.dataset) / self.batch_size)

        def collate_fn(self, batch):
            if isinstance(batch[0], tuple):
                if len(batch[0]) == 2:
                    data, targets = zip(*batch)
                    data = mx.array(data)
                    targets = mx.array(targets)
                    return data, targets
                if len(batch[0]) == 4:
                    data, decoder_input, src_valid_len, targets = zip(*batch)
                    data = mx.array(data)
                    decoder_input = mx.array(decoder_input)
                    src_valid_len = mx.array(src_valid_len)
                    targets = mx.array(targets)
                    return data, decoder_input, src_valid_len, targets
            return mx.array(batch)
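
A small sketch of the ``Dataset``/``DataLoader`` pair and of ``bleu`` (the
arrays and sentences below are made up for illustration). The last batch is
smaller because ``drop_last`` defaults to ``False``, and a prediction
identical to its label scores a BLEU of 1:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X = mx.random.normal(shape=(10, 3))
    y = mx.arange(10)
    for Xb, yb in d2l.DataLoader(d2l.Dataset(X, y), batch_size=4):
        print(Xb.shape, yb.shape)  # (4, 3) (4,) twice, then (2, 3) (2,)

    print(d2l.bleu('there is a cat', 'there is a cat', k=2))  # 1.0

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
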
    class AdaptiveAvgPool2d(nn.Module):  #@save
        """Applies a 2D adaptive average pooling over an input signal.

        The output spatial dimensions are specified by `output_size`. This
        implementation uses a standard `mlx.nn.AvgPool2d` layer with
        dynamically calculated kernel size and stride to achieve the target
        output size. It is primarily designed for downsampling
        (input size >= output size).

        Args:
            output_size (int or tuple): The target output size of the image
                of the form H x W. Can be a single integer H to specify
                H x H, or a tuple of two integers (H, W).
        """
        def __init__(self, output_size):
            super().__init__()
            if isinstance(output_size, int):
                self.output_h = output_size
                self.output_w = output_size
            elif isinstance(output_size, tuple) and len(output_size) == 2 and \
                    isinstance(output_size[0], int) and \
                    isinstance(output_size[1], int):
                self.output_h = output_size[0]
                self.output_w = output_size[1]
            else:
                raise ValueError(
                    "output_size must be an int or a tuple of two positive ints"
                )
            if self.output_h <= 0 or self.output_w <= 0:
                raise ValueError("output_size dimensions must be positive")

        def __call__(self, x: mx.array) -> mx.array:
            """Forward pass for AdaptiveAvgPool2d.

            Args:
                x (mx.array): Input tensor of shape (N, H_in, W_in, C).
                    MLX typically uses channels-last format.

            Returns:
                mx.array: Output tensor of shape (N, H_out, W_out, C).
            """
            input_h = x.shape[1]
            input_w = x.shape[2]

            if self.output_h == input_h and self.output_w == input_w:
                return x

            # Handle the common global average pooling case efficiently
            if self.output_h == 1 and self.output_w == 1:
                return mx.mean(x, axis=(1, 2), keepdims=True)

            # This implementation relies on nn.AvgPool2d and is suited for
            # downsampling only.
            if input_h < self.output_h or input_w < self.output_w:
                raise ValueError(
                    f"Input spatial size ({input_h}x{input_w}) is smaller than "
                    f"target output size ({self.output_h}x{self.output_w}) in "
                    "at least one dimension. This AdaptiveAvgPool2d "
                    "implementation using nn.AvgPool2d requires input "
                    "dimensions to be greater than or equal to output "
                    "dimensions."
                )

            # Stride: floor division of input size by output size
            stride_h = input_h // self.output_h
            stride_w = input_w // self.output_w

            # Kernel size from the formula K = I - (O - 1) * S, which ensures
            # that an AvgPool2d with this kernel and stride produces an output
            # of the target size O
            kernel_h = input_h - (self.output_h - 1) * stride_h
            kernel_w = input_w - (self.output_w - 1) * stride_w

            # Validate the derived parameters; this should always hold given
            # the size checks above and the positivity checks in __init__
            if kernel_h <= 0 or kernel_w <= 0 or stride_h <= 0 or stride_w <= 0:
                raise RuntimeError(
                    f"Calculated invalid pooling parameters: "
                    f"kernel=({kernel_h},{kernel_w}), "
                    f"stride=({stride_h},{stride_w}). "
                    f"Input: ({input_h},{input_w}), "
                    f"Output: ({self.output_h},{self.output_w})"
                )

            pool_layer = nn.AvgPool2d(
                kernel_size=(kernel_h, kernel_w),
                stride=(stride_h, stride_w),
                padding=0  # Adaptive pooling does not use explicit padding
            )
            return pool_layer(x)
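
For instance, pooling a 6x6 feature map down to 2x2 yields a stride of 3 and
a kernel of 3 per dimension (K = 6 - (2 - 1) * 3 = 3). A quick check:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    pool = d2l.AdaptiveAvgPool2d((2, 2))
    x = mx.random.normal(shape=(1, 6, 6, 3))  # NHWC, as MLX expects
    d2l.check_shape(pool(x), (1, 2, 2, 3))

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python
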
" f"Input: ({input_h},{input_w}), Output: ({self.output_h},{self.output_w})" ) pool_layer = nn.AvgPool2d( kernel_size=(kernel_h, kernel_w), stride=(stride_h, stride_w), padding=0 # Adaptive pooling typically does not involve explicit padding ) return pool_layer(x) class resnet18(nn.Module): #@save """稍加修改的ResNet-18模型 Defined in :numref:`sec_multi_gpu_concise`""" def __init__(self, num_classes, in_channels=1): super(resnet18, self).__init__() def resnet_block(in_channels, out_channels, num_residuals, first_block=False): blk = [] for i in range(num_residuals): if i == 0 and not first_block: blk.append(d2l.Residual(in_channels, out_channels, use_1x1conv=True, strides=2)) else: blk.append(d2l.Residual(out_channels, out_channels)) #return nn.Sequential(*blk) return blk b1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,padding=3), nn.BatchNorm(64), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True)) b3 = nn.Sequential(*resnet_block(64, 128, 2)) b4 = nn.Sequential(*resnet_block(128, 256, 2)) b5 = nn.Sequential(*resnet_block(256, 512, 2)) self.layers = nn.Sequential(b1, b2, b3, b4, b5, # nn.AvgPool2d((3,3)), d2l.AdaptiveAvgPool2d((1, 1)), d2l.Flatten(), nn.Linear(512, num_classes)) def __call__(self, x): # 该模型使用了更小的卷积核、步长和填充,而且删除了最大汇聚层 x = self.layers(x) return x