Files

58 lines
1.4 KiB
JavaScript

import { Registry } from 'mixly';
import * as tf from '@tensorflow/tfjs';
// In-memory store of model file contents, keyed by the path tfjs requests.
// Written by tensorflow.setModelsValue and read by customFetch.
const modelsValueRegistry = new Registry();
/**
 * Fetch replacement handed to tfjs as `fetchFunc`: serves model files from
 * `modelsValueRegistry` instead of going to the network.
 *
 * tfjs types `fetchFunc` as `typeof fetch`, so this resolves (asynchronously)
 * to a minimal Response-like object exposing `ok`, `status`, `statusText`,
 * `json()` and `arrayBuffer()`, with both body readers returning Promises.
 *
 * @param {string} path - URL requested by tfjs; used verbatim as the registry key.
 * @returns {Promise<{ok: boolean, status: number, statusText: string, buffer: *,
 *   json: function(): Promise<*>, arrayBuffer: function(): Promise<*>}>}
 *   `ok` is false and `status` is 404 when nothing is registered under `path`.
 */
const customFetch = async function (path) {
    const found = modelsValueRegistry.hasKey(path);
    const buffer = found ? modelsValueRegistry.getItem(path) : null;
    return {
        ok: found,
        // tfjs includes `status` in its "request failed" error message.
        status: found ? 200 : 404,
        statusText: found ? 'OK' : 'Not Found',
        buffer,
        // Readers close over `buffer` rather than using `this`, so they still
        // work if a caller detaches them from the response object.
        async json() {
            // NOTE(review): assumes the stored value is an ArrayBuffer or typed
            // array, as TextDecoder requires; confirm against setModelsValue callers.
            const jsonText = new TextDecoder('utf-8').decode(buffer);
            return JSON.parse(jsonText);
        },
        async arrayBuffer() {
            return buffer;
        },
    };
};
/**
 * Helper namespace for loading tfjs models from in-memory files.
 * `modelsValue` mirrors every value passed to `tensorflow.setModelsValue`,
 * keyed by its path.
 */
const tensorflow = { modelsValue: {} };
/**
 * Load a tfjs GraphModel whose files are served by `customFetch` from the
 * in-memory registry instead of over HTTP.
 *
 * @param {string} path - Model URL; presumably registered beforehand via
 *   `tensorflow.setModelsValue` (otherwise customFetch reports ok: false).
 * @returns {Promise<*>} Whatever `tf.loadGraphModel` resolves to.
 */
tensorflow.loadGraphModel = async function (path) {
    const loadOptions = {
        fromTFHub: false,
        // Only the requested URL is forwarded; request init and extras are ignored.
        fetchFunc: (url) => customFetch(url),
    };
    return await tf.loadGraphModel(path, loadOptions);
};
/**
 * Load a tfjs LayersModel whose files are served by `customFetch` from the
 * in-memory registry instead of over HTTP.
 *
 * @param {string} path - Model URL; presumably registered beforehand via
 *   `tensorflow.setModelsValue` (otherwise customFetch reports ok: false).
 * @returns {Promise<*>} Whatever `tf.loadLayersModel` resolves to.
 */
tensorflow.loadLayersModel = async function (path) {
    const loadOptions = {
        fromTFHub: false,
        // Only the requested URL is forwarded; request init and extras are ignored.
        fetchFunc: (url) => customFetch(url),
    };
    return await tf.loadLayersModel(path, loadOptions);
};
/**
 * Register (or replace) the contents of a model file so customFetch can serve it.
 *
 * @param {string} path - Key tfjs will request (model JSON or weight file URL).
 * @param {*} value - File contents to serve for `path`.
 */
tensorflow.setModelsValue = function (path, value) {
    const registry = modelsValueRegistry;
    const isReplacing = registry.hasKey(path);
    // Remove the previous entry first; presumably Registry.register does not
    // overwrite existing keys (TODO confirm against mixly's Registry).
    if (isReplacing) {
        registry.unregister(path);
    }
    registry.register(path, value);
    // Keep the public mirror in sync with the registry.
    tensorflow.modelsValue[path] = value;
};
// Publish the helper namespace as a browser global so code outside this
// module can reach it (presumably generated block code; confirm usage).
Object.assign(window, { tensorflow });