1.6. Utility Functions and Classes
This section contains the implementations of utility functions and classes used in this book.
import collections
import hashlib
import inspect
import math
import os
import random
import tarfile
import zipfile

import numpy as np
import requests

import mlx
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
from IPython import display
from d2l import mlx as d2l
class HyperParameters: #@save
"""The base class of hyperparameters."""
def save_hyperparameters(self, ignore=[]):
"""Defined in :numref:`sec_oo-design`"""
        raise NotImplementedError
def save_hyperparameters(self, ignore=[]):
"""Save function arguments into class attributes.
Defined in :numref:`sec_utils`"""
frame = inspect.currentframe().f_back
_, _, _, local_vars = inspect.getargvalues(frame)
self.hparams = {k:v for k, v in local_vars.items()
if k not in set(ignore+['self']) and not k.startswith('_')}
for k, v in self.hparams.items():
setattr(self, k, v)
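As a quick illustration, constructor arguments become attributes automatically; this is a minimal sketch (the class B and its arguments are hypothetical), mirroring the book's example in :numref:`sec_oo-design`:
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)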
class DataModule(d2l.HyperParameters): #@save
"""The base class of data.
Defined in :numref:`subsec_oo-design-models`"""
def __init__(self, root='../data'):
self.save_hyperparameters()
def get_dataloader(self, train):
raise NotImplementedError
def train_dataloader(self):
return self.get_dataloader(train=True)
def val_dataloader(self):
return self.get_dataloader(train=False)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
"""Defined in :numref:`sec_synthetic-regression-data`"""
tensors = tuple(a[indices] for a in tensors)
dataset = d2l.Dataset(*tensors)
return d2l.DataLoader(dataset, self.batch_size, shuffle=train)
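A minimal sketch of a concrete data module; the SyntheticData name and sizes are illustrative only (it also assumes mx.array can stack a sequence of arrays, which the DataLoader.collate_fn below already relies on):
class SyntheticData(d2l.DataModule):
    def __init__(self, batch_size=4, n=20):
        super().__init__()
        self.save_hyperparameters()
        self.X = mx.random.normal(shape=(n, 2))
        self.y = mx.zeros((n,))

    def get_dataloader(self, train):
        # Reuse the tensor loader defined above
        return self.get_tensorloader((self.X, self.y), train)

X, y = next(iter(SyntheticData().train_dataloader()))
print(X.shape, y.shape)  # (4, 2) and (4,)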
class ProgressBoard(d2l.HyperParameters): #@save
"""The board that plots data points in animation.
Defined in :numref:`sec_oo-design`"""
def __init__(self, xlabel=None, ylabel=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
fig=None, axes=None, figsize=(3.5, 2.5), display=True):
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
        raise NotImplementedError
def draw(self, x, y, label, every_n=1):
"""Defined in :numref:`sec_utils`"""
Point = collections.namedtuple('Point', ['x', 'y'])
if not hasattr(self, 'raw_points'):
self.raw_points = collections.OrderedDict()
self.data = collections.OrderedDict()
if label not in self.raw_points:
self.raw_points[label] = []
self.data[label] = []
points = self.raw_points[label]
line = self.data[label]
points.append(Point(x, y))
if len(points) != every_n:
return
mean = lambda x: sum(x) / len(x)
line.append(Point(mean([p.x for p in points]),
mean([p.y for p in points])))
points.clear()
if not self.display:
return
d2l.use_svg_display()
if self.fig is None:
self.fig = d2l.plt.figure(figsize=self.figsize)
plt_lines, labels = [], []
for (k, v), ls, color in zip(self.data.items(), self.ls, self.colors):
plt_lines.append(d2l.plt.plot([p.x for p in v], [p.y for p in v],
linestyle=ls, color=color)[0])
labels.append(k)
axes = self.axes if self.axes else d2l.plt.gca()
if self.xlim: axes.set_xlim(self.xlim)
if self.ylim: axes.set_ylim(self.ylim)
        if not self.xlabel: self.xlabel = 'x'
axes.set_xlabel(self.xlabel)
axes.set_ylabel(self.ylabel)
axes.set_xscale(self.xscale)
axes.set_yscale(self.yscale)
axes.legend(plt_lines, labels)
display.display(self.fig)
display.clear_output(wait=True)
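For example, in a notebook the board can animate smoothed curves point by point (a sketch adapted from the book's usage in :numref:`sec_oo-design`):
board = d2l.ProgressBoard(xlabel='x')
for x in mx.arange(0, 10, 0.1):
    board.draw(float(x), float(mx.sin(x)), 'sin', every_n=2)
    board.draw(float(x), float(mx.cos(x)), 'cos', every_n=10)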
class Module(nn.Module, d2l.HyperParameters): #@save
"""The base class of models.
Defined in :numref:`sec_oo-design`"""
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
def loss(self, y_hat, y):
raise NotImplementedError
def __call__(self, X):
        assert hasattr(self, 'net'), 'Neural network is not defined'
return self.net(X)
def plot(self, key, value, train):
"""Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not initialized'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, np.array(value),
('train_' if train else 'val_') + key,
every_n=int(n))
def training_step(self, batch):
def loss_fn(X, y):
return self.loss(self(*X), y)
        l, grad = nn.value_and_grad(self, loss_fn)(batch[:-1], batch[-1])
self.plot('loss', l.item(), train=True)
return l, grad
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l.item(), train=False)
def configure_optimizers(self):
raise NotImplementedError
def configure_optimizers(self):
"""Defined in :numref:`sec_classification`"""
return optim.SGD(learning_rate=self.lr)
class RNNScratch(d2l.Module): #@save
"""The RNN model implemented from scratch.
Defined in :numref:`sec_rnn-scratch`"""
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
self.W_xh = mx.random.normal(shape=(num_inputs, num_hiddens)) * sigma
self.W_hh = mx.random.normal(shape=(num_hiddens, num_hiddens)) * sigma
self.b_h = mx.zeros(num_hiddens)
def forward(self, inputs, state=None):
"""Defined in :numref:`sec_rnn-scratch`"""
if state is None:
# Initial state with shape: (batch_size, num_hiddens)
state = mx.zeros((inputs.shape[1], self.num_hiddens))
else:
state, = state
outputs = []
for X in inputs: # Shape of inputs: (num_steps, batch_size, num_inputs)
state = mx.tanh(mx.matmul(X, self.W_xh) +
mx.matmul(state, self.W_hh) + self.b_h)
outputs.append(state)
return outputs, state
    def __call__(self, inputs, state=None):
        """Defined in :numref:`sec_rnn-scratch`"""
        return self.forward(inputs, state)
def check_len(a, n): #@save
"""Check the length of a list.
Defined in :numref:`sec_rnn-scratch`"""
assert len(a) == n, f'list\'s length {len(a)} != expected length {n}'
def check_shape(a, shape): #@save
"""Check the shape of a tensor.
Defined in :numref:`sec_rnn-scratch`"""
assert a.shape == shape, \
f'tensor\'s shape {a.shape} != expected shape {shape}'
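A quick sanity check of RNNScratch with these helpers, using illustrative sizes (mirroring the book's usage in :numref:`sec_rnn-scratch`):
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = d2l.RNNScratch(num_inputs, num_hiddens)
X = mx.ones((num_steps, batch_size, num_inputs))
outputs, state = rnn(X)
d2l.check_len(outputs, num_steps)
d2l.check_shape(outputs[0], (batch_size, num_hiddens))
d2l.check_shape(state, (batch_size, num_hiddens))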
class RNN(d2l.Module): #@save
"""The RNN model implemented with high-level APIs.
Defined in :numref:`sec_rnn-concise`"""
def __init__(self, num_inputs, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = d2l.RNNScratch(num_inputs, num_hiddens)
def __call__(self, inputs, H=None):
return self.rnn(inputs, H)
class GRUScratch(d2l.Module): #@save
"""Defined in :numref:`sec_gru`"""
def __init__(self, num_inputs, num_hiddens, sigma=0.01, dropout=0):
super().__init__()
self.save_hyperparameters()
init_weight = lambda *shape: mx.random.normal(*shape) * sigma
triple = lambda: (init_weight((num_inputs, num_hiddens)),
init_weight((num_hiddens, num_hiddens)),
mx.zeros(shape=(num_hiddens,)))
self.W_xz, self.W_hz, self.b_z = triple() # Update gate
self.W_xr, self.W_hr, self.b_r = triple() # Reset gate
self.W_xh, self.W_hh, self.b_h = triple() # Candidate hidden state
def __call__(self, inputs, H=None):
"""Defined in :numref:`sec_gru`"""
if H is None:
# Initial state with shape: (batch_size, num_hiddens)
H = mx.zeros((inputs.shape[1], self.num_hiddens))
outputs = []
for X in inputs:
Z = mx.sigmoid(mx.matmul(X, self.W_xz) +
mx.matmul(H, self.W_hz) + self.b_z)
R = mx.sigmoid(mx.matmul(X, self.W_xr) +
mx.matmul(H, self.W_hr) + self.b_r)
H_tilde = mx.tanh(mx.matmul(X, self.W_xh) +
mx.matmul(R * H, self.W_hh) + self.b_h)
H = Z * H + (1 - Z) * H_tilde
H = d2l.dropout_layer(H, self.dropout)
outputs.append(H)
return outputs, H
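The same kind of shape check works for the scratch GRU (illustrative sizes):
gru = d2l.GRUScratch(num_inputs=8, num_hiddens=16)
outputs, H = gru(mx.ones((5, 4, 8)))  # (num_steps, batch_size, num_inputs)
d2l.check_len(outputs, 5)
d2l.check_shape(H, (4, 16))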
class StackedGRUScratch(d2l.Module): #@save
"""Defined in :numref:`sec_deep_rnn`"""
def __init__(self, num_inputs, num_hiddens, num_layers, sigma=0.01, dropout=0):
super().__init__()
self.save_hyperparameters()
        # Forward the dropout hyperparameter to each layer; otherwise it is unused
        self.grus = nn.Sequential(*[d2l.GRUScratch(
            num_inputs if i==0 else num_hiddens, num_hiddens, sigma, dropout)
            for i in range(num_layers)])
def __call__(self, inputs, Hs=None):
outputs = inputs
if Hs is None: Hs = [None] * self.num_layers
for i in range(self.num_layers):
outputs, Hs[i] = self.grus.layers[i](outputs, Hs[i])
outputs = mx.stack(outputs, axis=0)
return outputs, Hs
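Stacking the per-step outputs yields a single array, with one final state per layer (a sketch with illustrative sizes):
stacked = d2l.StackedGRUScratch(num_inputs=8, num_hiddens=16, num_layers=2)
outputs, Hs = stacked(mx.ones((5, 4, 8)))
d2l.check_shape(outputs, (5, 4, 16))
d2l.check_len(Hs, 2)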
class GRU(d2l.RNN): #@save
"""The multilayer GRU model.
Defined in :numref:`sec_deep_rnn`"""
def __init__(self, num_inputs, num_hiddens, num_layers, dropout=0):
d2l.Module.__init__(self)
self.save_hyperparameters()
self.rnn = d2l.StackedGRUScratch(num_inputs, num_hiddens, num_layers, sigma=0.01, dropout=dropout)
def dropout_layer(X, dropout): #@save
"""Defined in :numref:`sec_dropout`"""
assert 0 <= dropout <= 1
if dropout == 1: return mx.zeros_like(X)
mask = (mx.random.uniform(shape=X.shape) > dropout).astype(mx.float32)
return mask * X / (1.0 - dropout)
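A quick numeric check (the middle case is random by design; with dropout 0 the mask is almost surely all ones):
X = mx.ones((2, 8))
print(d2l.dropout_layer(X, 0.0))  # unchanged
print(d2l.dropout_layer(X, 0.5))  # roughly half zeroed, survivors scaled to 2
print(d2l.dropout_layer(X, 1.0))  # all zeros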
def add_to_class(Class): #@save
"""Register functions as methods in created class.
Defined in :numref:`sec_oo-design`"""
def wrapper(obj):
setattr(Class, obj.__name__, obj)
return wrapper
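The decorator registers a function on an existing class, as in the book's example from :numref:`sec_oo-design`:
class A:
    def __init__(self):
        self.b = 1

a = A()

@d2l.add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()  # prints: Class attribute "b" is 1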
def num_gpus(): #@save
"""Get the number of available GPUs.
Defined in :numref:`sec_use_gpu`"""
return 1
def gpu(i=0): #@save
"""Get a GPU device.
Defined in :numref:`sec_use_gpu`"""
mx.set_default_device(mx.gpu)
return mx.default_device()
def try_gpu(i=0): #@save
"""Return gpu(i) if exists, otherwise return cpu().
Defined in :numref:`sec_use_gpu`"""
return gpu(0)
class Trainer(d2l.HyperParameters): #@save
"""The base class for training models with data.
Defined in :numref:`subsec_oo-design-models`"""
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
assert num_gpus == 0, 'No GPU support yet'
def prepare_data(self, data):
self.train_dataloader = data.train_dataloader()
self.val_dataloader = data.val_dataloader()
self.num_train_batches = len(self.train_dataloader)
self.num_val_batches = (len(self.val_dataloader)
if self.val_dataloader is not None else 0)
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
self.model = model
def fit(self, model, data):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
def fit_epoch(self):
raise NotImplementedError
def prepare_batch(self, batch):
"""Defined in :numref:`sec_linear_scratch`"""
return batch
def fit_epoch(self):
"""Defined in :numref:`sec_linear_scratch`"""
self.model.train(True)
for batch in self.train_dataloader:
loss, grads = self.model.training_step(self.prepare_batch(batch))
if self.gradient_clip_val > 0:
grads = self.clip_gradients(self.gradient_clip_val, grads)
self.optim.update(model=self.model, gradients=grads)
mx.eval(self.model.parameters())
self.train_batch_idx += 1
if self.val_dataloader is None:
return
self.model.eval()
for batch in self.val_dataloader:
self.model.validation_step(self.prepare_batch(batch))
self.val_batch_idx += 1
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
"""Defined in :numref:`sec_use_gpu`"""
self.save_hyperparameters()
self.gpus = [d2l.gpu(i) for i in range(min(num_gpus, d2l.num_gpus()))]
    def prepare_batch(self, batch):
        """Defined in :numref:`sec_use_gpu`"""
        if self.gpus:
            # MLX arrays live in unified memory, so selecting the GPU device
            # is enough; no per-array transfer is needed.
            gpu()
        return batch
def prepare_model(self, model):
"""Defined in :numref:`sec_use_gpu`"""
model.trainer = self
model.board.xlim = [0, self.max_epochs]
        if self.gpus:
            # Parameters already live in unified memory; selecting the GPU
            # device is enough.
            gpu()
self.model = model
def clip_gradients(self, grad_clip_val, grads):
"""Defined in :numref:`sec_rnn-scratch`
Defined in :numref:`sec_rnn-scratch`"""
grad_leaves = mlx.utils.tree_flatten(grads)
norm = mx.sqrt(sum((x[1] ** 2).sum() for x in grad_leaves))
clip = lambda grad: mx.where(norm < grad_clip_val, grad, grad * (grad_clip_val / norm))
return mlx.utils.tree_map(clip, grads)
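A toy check of the clipping rule, which rescales the whole gradient tree to norm grad_clip_val whenever the global norm exceeds it (values are illustrative):
trainer = d2l.Trainer(max_epochs=1)
grads = {'w': mx.array([3.0, 4.0])}           # global norm is 5
clipped = trainer.clip_gradients(1.0, grads)
print(clipped['w'])                           # array([0.6, 0.8])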
def download_new(url, folder='../data', sha1_hash=None): #@save
"""Download a file to folder and return the local filepath.
Defined in :numref:`sec_utils`"""
if not url.startswith('http'):
        # For backward compatibility
url, sha1_hash = DATA_HUB[url]
os.makedirs(folder, exist_ok=True)
fname = os.path.join(folder, url.split('/')[-1])
# Check if hit cache
if os.path.exists(fname) and sha1_hash:
sha1 = hashlib.sha1()
with open(fname, 'rb') as f:
while True:
data = f.read(1048576)
if not data:
break
sha1.update(data)
if sha1.hexdigest() == sha1_hash:
return fname
# Download
print(f'Downloading {fname} from {url}...')
r = requests.get(url, stream=True, verify=True)
with open(fname, 'wb') as f:
f.write(r.content)
return fname
def extract(filename, folder=None): #@save
"""Extract a zip/tar file into folder.
Defined in :numref:`sec_utils`"""
base_dir = os.path.dirname(filename)
_, ext = os.path.splitext(filename)
    assert ext in ('.zip', '.tar', '.gz'), 'Only zip/tar files are supported.'
if ext == '.zip':
fp = zipfile.ZipFile(filename, 'r')
else:
fp = tarfile.open(filename, 'r')
if folder is None:
folder = base_dir
fp.extractall(folder)
class Classifier(d2l.Module): #@save
"""The base class of classification models.
Defined in :numref:`sec_classification`"""
def validation_step(self, batch):
Y_hat = self(*batch[:-1])
self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
def accuracy(self, Y_hat, Y, averaged=True):
"""Compute the number of correct predictions.
Defined in :numref:`sec_classification`"""
Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
preds = Y_hat.argmax(axis=1).astype(Y.dtype)
compare = (preds == Y.reshape(-1)).astype(mx.float32)
return compare.mean() if averaged else compare
def loss(self, Y_hat, Y, averaged=True):
"""Defined in :numref:`sec_softmax_concise`"""
Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
Y = Y.reshape((-1,))
return nn.losses.cross_entropy(
Y_hat, Y, reduction='mean' if averaged else 'none')
def layer_summary(self, X_shape):
"""Defined in :numref:`sec_lenet`"""
        X = mx.random.normal(shape=X_shape)
for layer in self.net.layers:
X = layer(X)
print(layer.__class__.__name__, 'output shape:\t', X.shape)
class MTFraEng(d2l.DataModule): #@save
"""The English-French dataset.
Defined in :numref:`sec_machine_translation`"""
def _download_new(self):
d2l.extract(d2l.download_new(
d2l.DATA_URL+'fra-eng.zip', self.root,
'94646ad1522d915e7b0f9296181140edcf86a4f5'))
with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
return f.read()
def _preprocess(self, text):
"""Defined in :numref:`sec_machine_translation`"""
# Replace non-breaking space with space
text = text.replace('\u202f', ' ').replace('\xa0', ' ')
# Insert space between words and punctuation marks
no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
for i, char in enumerate(text.lower())]
return ''.join(out)
def _tokenize(self, text, max_examples=None):
"""Defined in :numref:`sec_machine_translation`"""
src, tgt = [], []
for i, line in enumerate(text.split('\n')):
if max_examples and i > max_examples: break
parts = line.split('\t')
if len(parts) == 2:
# Skip empty tokens
src.append([t for t in f'{parts[0]} <eos>'.split(' ') if t])
tgt.append([t for t in f'{parts[1]} <eos>'.split(' ') if t])
return src, tgt
def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
"""Defined in :numref:`sec_machine_translation`"""
super(MTFraEng, self).__init__()
self.save_hyperparameters()
self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
self._download_new())
def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
"""Defined in :numref:`subsec_mt_data_loading`"""
def _build_array(sentences, vocab, is_tgt=False):
pad_or_trim = lambda seq, t: (
seq[:t] if len(seq) > t else seq + ['<pad>'] * (t - len(seq)))
sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
if is_tgt:
sentences = [['<bos>'] + s for s in sentences]
if vocab is None:
vocab = d2l.Vocab(sentences, min_freq=2)
array = mx.array([vocab[s] for s in sentences])
valid_len = (array != vocab['<pad>']).astype(mx.int32).sum(1)
return array, vocab, valid_len
src, tgt = self._tokenize(self._preprocess(raw_text),
self.num_train + self.num_val)
src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),
src_vocab, tgt_vocab)
def get_dataloader(self, train):
"""Defined in :numref:`subsec_mt_data_loading`"""
idx = slice(0, self.num_train) if train else slice(self.num_train, None)
return self.get_tensorloader(self.arrays, train, idx)
def build(self, src_sentences, tgt_sentences):
"""Defined in :numref:`subsec_mt_data_loading`"""
raw_text = '\n'.join([src + '\t' + tgt for src, tgt in zip(
src_sentences, tgt_sentences)])
arrays, _, _ = self._build_arrays(
raw_text, self.src_vocab, self.tgt_vocab)
return arrays
class Encoder(d2l.Module): #@save
"""编码器-解码器架构的基本编码器接口"""
def __init__(self):
super().__init__()
def __call__(self, X, *args):
raise NotImplementedError
class Decoder(d2l.Module): #@save
"""编码器-解码器架构的基本解码器接口
Defined in :numref:`sec_encoder-decoder`"""
def __init__(self):
super().__init__()
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def __call__(self, X, state):
raise NotImplementedError
class EncoderDecoder(d2l.Classifier): #@save
"""编码器-解码器架构的基类
Defined in :numref:`sec_encoder-decoder`"""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def __call__(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
return self.decoder(dec_X, dec_state)[0]
def predict_step(self, batch, device, num_steps,
save_attention_weights=False):
"""Defined in :numref:`sec_seq2seq_decoder`"""
src, tgt, src_valid_len, _ = batch
enc_all_outputs = self.encoder(src, src_valid_len)
dec_state = self.decoder.init_state(enc_all_outputs, src_valid_len)
        outputs, attention_weights = [tgt[:, 0].reshape(-1, 1)], []
for _ in range(num_steps):
Y, dec_state = self.decoder(outputs[-1], dec_state)
outputs.append(Y.argmax(2))
# Save attention weights (to be covered later)
if save_attention_weights:
attention_weights.append(self.decoder.attention_weights)
return mx.concatenate(outputs[1:], 1), attention_weights
def init_seq2seq(array): #@save
"""Defined in :numref:`sec_seq2seq`"""
if array.ndim > 1:
weight_fn = nn.init.glorot_uniform()
array = weight_fn(array)
return array
class Seq2SeqEncoder(d2l.Encoder): #@save
"""用于序列到序列学习的循环神经网络编码器
Defined in :numref:`sec_seq2seq`"""
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = d2l.GRU(embed_size, num_hiddens, num_layers, dropout)
# print("before init:", tree_flatten(self.parameters()))
for module in self.modules():
if isinstance(module, nn.Linear) or isinstance(module, d2l.GRU):
module.update(mlx.utils.tree_map(lambda x: init_seq2seq(x), module.parameters()))
def __call__(self, X, *args):
# X shape: (batch_size, num_steps)
embs = self.embedding(X.T.astype(mx.int64))
# embs shape: (num_steps, batch_size, embed_size)
outputs, state = self.rnn(embs)
state = mx.array(state)
# outputs shape: (num_steps, batch_size, num_hiddens)
# state shape: (num_layers, batch_size, num_hiddens)
return outputs, state
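A shape check with illustrative sizes (mirroring the book's usage in :numref:`sec_seq2seq`):
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 9
encoder = d2l.Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, num_layers)
outputs, state = encoder(mx.zeros((batch_size, num_steps), dtype=mx.int32))
d2l.check_shape(outputs, (num_steps, batch_size, num_hiddens))
d2l.check_shape(state, (num_layers, batch_size, num_hiddens))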
class Seq2SeqDecoder(d2l.Decoder): #@save
"""The RNN decoder for sequence to sequence learning.
Defined in :numref:`sec_seq2seq`"""
def __init__(self, num_inputs, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = d2l.GRU(embed_size+num_hiddens, num_hiddens,
num_layers, dropout)
self.dense = nn.Linear(num_inputs, vocab_size)
# print("before init:", tree_flatten(self.parameters()))
for module in self.modules():
if isinstance(module, nn.Linear) or isinstance(module, d2l.GRU):
module.update(mlx.utils.tree_map(lambda x: init_seq2seq(x), module.parameters()))
def init_state(self, enc_all_outputs, *args):
return enc_all_outputs
def __call__(self, X, state):
# X shape: (batch_size, num_steps)
# embs shape: (num_steps, batch_size, embed_size)
embs = self.embedding(X.T.astype(mx.int32))
enc_output, hidden_state = state
# context shape: (batch_size, num_hiddens)
context = enc_output[-1]
# Broadcast context to (num_steps, batch_size, num_hiddens)
context = mx.tile(context, (embs.shape[0], 1, 1))
# Concat at the feature dimension
embs_and_context = mx.concatenate((embs, context), -1)
outputs, hidden_state = self.rnn(embs_and_context, hidden_state)
outputs = self.dense(outputs).swapaxes(0, 1)
# outputs shape: (batch_size, num_steps, vocab_size)
# hidden_state shape: (num_layers, batch_size, num_hiddens)
return outputs, [enc_output, hidden_state]
class Seq2Seq(d2l.EncoderDecoder): #@save
"""The RNN encoder--decoder for sequence to sequence learning.
Defined in :numref:`sec_seq2seq_decoder`"""
def __init__(self, encoder, decoder, tgt_pad, lr):
super().__init__(encoder, decoder)
self.save_hyperparameters()
def validation_step(self, batch):
Y_hat = self(*batch[:-1])
self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
def configure_optimizers(self):
# Adam optimizer is used here
return optim.Adam(learning_rate=self.lr)
def loss(self, Y_hat, Y):
"""Defined in :numref:`sec_seq2seq_decoder`"""
l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False)
mask = (Y.reshape(-1) != self.tgt_pad).astype(mx.float32)
return (l * mask).sum() / mask.sum()
def bleu(pred_seq, label_seq, k): #@save
"""计算BLEU
Defined in :numref:`sec_seq2seq_decoder`"""
pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
len_pred, len_label = len(pred_tokens), len(label_tokens)
score = math.exp(min(0, 1 - len_label / len_pred))
for n in range(1, min(k, len_pred) + 1):
num_matches, label_subs = 0, collections.defaultdict(int)
for i in range(len_label - n + 1):
label_subs[' '.join(label_tokens[i: i + n])] += 1
for i in range(len_pred - n + 1):
if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
num_matches += 1
label_subs[' '.join(pred_tokens[i: i + n])] -= 1
score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
return score
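For instance, a prediction that only misses a trailing word pays just the brevity penalty (a small worked example):
print(d2l.bleu('he is calm', 'he is calm today', k=2))
# All 1-grams and 2-grams match, so the score is exp(1 - 4/3) ≈ 0.7165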
class Dataset: #@save
def __init__(self, *tensors):
"""Defined in :numref:`sec_fashion_mnist`"""
assert all(
tensors[0].shape[0] == tensor.shape[0] for tensor in tensors
), "Size mismatch between tensors"
self.tensors = tensors
def __getitem__(self, index):
return tuple(mx.array(tensor[index]) for tensor in self.tensors)
def __len__(self):
return self.tensors[0].shape[0]
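A quick indexing check:
ds = d2l.Dataset(mx.arange(12).reshape(6, 2), mx.arange(6))
print(len(ds))  # 6
print(ds[2])    # (array([4, 5]), array(2))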
class DataLoader: #@save
def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False):
"""Defined in :numref:`sec_fashion_mnist`"""
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.indices = list(range(len(dataset)))
self.current_index = 0
if self.shuffle:
random.shuffle(self.indices)
def __iter__(self):
self.current_index = 0
if self.shuffle:
random.shuffle(self.indices)
return self
def __next__(self):
if self.current_index >= len(self.indices):
raise StopIteration
end_index = self.current_index + self.batch_size
if end_index > len(self.indices):
if self.drop_last:
raise StopIteration
else:
end_index = len(self.indices)
batch_indices = self.indices[self.current_index:end_index]
batch = [self.dataset[i] for i in batch_indices]
self.current_index = end_index
return self.collate_fn(batch)
def __len__(self):
if self.drop_last:
return len(self.dataset) // self.batch_size
else:
return math.ceil(len(self.dataset) / self.batch_size)
def collate_fn(self, batch):
if isinstance(batch[0], tuple):
            if len(batch[0]) == 2:
data, targets = zip(*batch)
data = mx.array(data)
targets = mx.array(targets)
return data, targets
            if len(batch[0]) == 4:
data, decoder_input, src_valid_len, targets = zip(*batch)
data = mx.array(data)
decoder_input = mx.array(decoder_input)
src_valid_len = mx.array(src_valid_len)
targets = mx.array(targets)
return data, decoder_input, src_valid_len, targets
return mx.array(batch)
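Iterating over a Dataset in mini-batches; the last, smaller batch is kept because drop_last defaults to False (this assumes mx.array stacks a sequence of arrays, as collate_fn above relies on):
ds = d2l.Dataset(mx.arange(12).reshape(6, 2), mx.arange(6))
loader = d2l.DataLoader(ds, batch_size=4)
for Xb, yb in loader:
    print(Xb.shape, yb.shape)  # (4, 2) (4,), then (2, 2) (2,)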
class AdaptiveAvgPool2d(nn.Module): #@save
"""Applies a 2D adaptive average pooling over an input signal.
The output spatial dimensions are specified by `output_size`.
This implementation uses a standard `mlx.nn.AvgPool2d` layer with
dynamically calculated kernel size and stride to achieve the target
output size. It is primarily designed for downsampling (input size >= output size).
Args:
output_size (int or tuple): The target output size of the image
of the form H x W. Can be a single integer H to specify H x H,
or a tuple of two integers (H, W).
"""
def __init__(self, output_size):
super().__init__()
if isinstance(output_size, int):
self.output_h = output_size
self.output_w = output_size
elif isinstance(output_size, tuple) and len(output_size) == 2 and \
isinstance(output_size[0], int) and isinstance(output_size[1], int):
self.output_h = output_size[0]
self.output_w = output_size[1]
else:
raise ValueError(
"output_size must be an int or a tuple of two positive ints"
)
if self.output_h <= 0 or self.output_w <= 0:
raise ValueError("output_size dimensions must be positive")
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass for AdaptiveAvgPool2d.
Args:
x (mx.array): Input tensor of shape (N, H_in, W_in, C).
MLX typically uses channels-last format.
Returns:
mx.array: Output tensor of shape (N, H_out, W_out, C).
"""
input_h = x.shape[1]
input_w = x.shape[2]
if self.output_h == input_h and self.output_w == input_w:
return x
# Handle the common Global Average Pooling case efficiently
if self.output_h == 1 and self.output_w == 1:
return mx.mean(x, axis=(1, 2), keepdims=True)
# This implementation relies on nn.AvgPool2d and is suited for downsampling.
if input_h < self.output_h or input_w < self.output_w:
raise ValueError(
f"Input spatial size ({input_h}x{input_w}) is smaller than target output size "
f"({self.output_h}x{self.output_w}) in at least one dimension. "
"This AdaptiveAvgPool2d implementation using nn.AvgPool2d "
"requires input dimensions to be greater than or equal to output dimensions."
)
# Calculate stride
# stride_h = math.floor(input_h / self.output_h) # floor is implicit with // for positive
stride_h = input_h // self.output_h
stride_w = input_w // self.output_w
# Calculate kernel size using the formula: K = I - (O - 1) * S
# This ensures that an AvgPool2d operation with this kernel and stride
# will produce an output of the target size O.
kernel_h = input_h - (self.output_h - 1) * stride_h
kernel_w = input_w - (self.output_w - 1) * stride_w
# Ensure kernel and stride are valid (should be if input_h >= output_h etc.)
if kernel_h <= 0 or kernel_w <= 0 or stride_h <= 0 or stride_w <= 0:
# This case should ideally be caught by the input_h < self.output_h checks,
# or if output_h/output_w are non-positive (checked in __init__).
raise RuntimeError(
f"Calculated invalid pooling parameters: "
f"kernel=({kernel_h},{kernel_w}), stride=({stride_h},{stride_w}). "
f"Input: ({input_h},{input_w}), Output: ({self.output_h},{self.output_w})"
)
pool_layer = nn.AvgPool2d(
kernel_size=(kernel_h, kernel_w),
stride=(stride_h, stride_w),
padding=0 # Adaptive pooling typically does not involve explicit padding
)
return pool_layer(x)
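For example, pooling a 6x6 map down to 2x2 uses stride 6 // 2 = 3 and kernel K = 6 - (2 - 1) * 3 = 3:
x = mx.random.normal(shape=(1, 6, 6, 3))  # NHWC, channels-last
pool = d2l.AdaptiveAvgPool2d((2, 2))
print(pool(x).shape)  # (1, 2, 2, 3)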
class resnet18(nn.Module): #@save
"""稍加修改的ResNet-18模型
Defined in :numref:`sec_multi_gpu_concise`"""
def __init__(self, num_classes, in_channels=1):
super(resnet18, self).__init__()
def resnet_block(in_channels, out_channels, num_residuals,
first_block=False):
blk = []
for i in range(num_residuals):
if i == 0 and not first_block:
blk.append(d2l.Residual(in_channels, out_channels,
use_1x1conv=True, strides=2))
else:
blk.append(d2l.Residual(out_channels, out_channels))
return blk
        b1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm(64), nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
self.layers = nn.Sequential(b1, b2, b3, b4, b5,
d2l.AdaptiveAvgPool2d((1, 1)),
d2l.Flatten(),
nn.Linear(512, num_classes))
def __call__(self, x):
        # Forward pass through the blocks defined above
x = self.layers(x)
return x
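A forward-shape sanity check (a sketch; it assumes the d2l.Residual and d2l.Flatten helpers saved elsewhere in the book are available):
net = d2l.resnet18(num_classes=10, in_channels=1)
X = mx.random.normal(shape=(1, 28, 28, 1))  # MLX convolutions expect NHWC
print(net(X).shape)  # (1, 10)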