You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
5.3 KiB
JavaScript
189 lines
5.3 KiB
JavaScript
|
|
// Experimental
|
|
|
|
var pickle = pickle || {};
|
|
var python = python || require('./python');
|
|
var zip = zip || require('./zip');
|
|
|
|
pickle.ModelFactory = class {

    /**
     * Decides whether the context holds a plain pickle file this factory opens.
     *
     * Streams beginning with the PyTorch signature below are rejected; otherwise
     * the context is accepted whenever `context.open('pkl')` returns anything
     * other than `undefined` (a `null` result is still accepted).
     *
     * @param {object} context - host context exposing `stream` and `open()`.
     * @returns {string|undefined} 'pickle' when accepted, otherwise undefined.
     */
    match(context) {
        const stream = context.stream;
        // `undefined` entries are wildcards that match any byte value.
        const signature = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
        const looksLikeTorch = stream.length >= signature.length &&
            stream.peek(signature.length).every((byte, i) => signature[i] === undefined || byte === signature[i]);
        if (looksLikeTorch) {
            // Reject PyTorch models with .pkl file extension.
            return undefined;
        }
        return context.open('pkl') === undefined ? undefined : 'pickle';
    }

    /**
     * Wraps the unpickled root object in a `pickle.Model`.
     *
     * Unrecognized roots are reported through `context.exception()` instead of
     * rejecting the promise; a model is still built from the object. Only the
     * cuML RandomForestClassifier type maps to a named format, and every other
     * root keeps the 'Pickle' format.
     *
     * @param {object} context - host context exposing `open()`, `exception()` and `identifier`.
     * @returns {Promise<pickle.Model>} the model for the unpickled object.
     */
    open(context) {
        return new Promise((resolve) => {
            const obj = context.open('pkl');
            const report = (message) => context.exception(new pickle.Error(message));
            const typeOf = (value) => value.__class__.__module__ + "." + value.__class__.__name__;
            let format = 'Pickle';
            if (obj === null || obj === undefined) {
                report("Unknown Pickle null object in '" + context.identifier + "'.");
            }
            else if (Array.isArray(obj)) {
                const head = obj[0];
                // True when `item` carries the same module/name class as the first element.
                const sharesHeadClass = (item) => item && item.__class__ && head.__class__ &&
                    item.__class__.__module__ === head.__class__.__module__ &&
                    item.__class__.__name__ === head.__class__.__name__;
                if (obj.length > 0 && head && obj.every(sharesHeadClass)) {
                    report("Unknown Pickle '" + typeOf(head) + "' array object in '" + context.identifier + "'.");
                }
                else {
                    report("Unknown Pickle array object in '" + context.identifier + "'.");
                }
            }
            else if (obj && obj.__class__) {
                const knownFormats = new Map([
                    [ 'cuml.ensemble.randomforestclassifier.RandomForestClassifier', 'cuML' ]
                ]);
                const type = typeOf(obj);
                if (knownFormats.has(type)) {
                    format = knownFormats.get(type);
                }
                else {
                    report("Unknown Pickle type '" + type + "' in '" + context.identifier + "'.");
                }
            }
            else {
                report("Unknown Pickle object in '" + context.identifier + "'.");
            }
            resolve(new pickle.Model(obj, format));
        });
    }
};
|
|
|
|
pickle.Model = class {

    /**
     * @param {*} value - unpickled root object, wrapped in a single graph.
     * @param {string} format - display name of the detected format.
     */
    constructor(value, format) {
        this._format = format;
        const graph = new pickle.Graph(value);
        this._graphs = [ graph ];
    }

    /** @returns {string} the format name given to the constructor. */
    get format() {
        return this._format;
    }

    /** @returns {Array} a one-element array holding the model's graph. */
    get graphs() {
        return this._graphs;
    }
};
|
|
|
|
pickle.Graph = class {

    /**
     * Builds the node list for an unpickled root value.
     *
     * - An array whose every element is a truthy value with `__class__` yields
     *   one node per element.
     * - A `Map` yields one node per entry, named after the entry key.
     * - Any other object (including arrays with mixed or null elements) yields
     *   a single node for the whole value.
     * - `null`, `undefined` and primitives yield no nodes.
     *
     * `inputs` and `outputs` are always empty.
     *
     * @param {*} obj - root value produced by the unpickler.
     */
    constructor(obj) {
        this._inputs = [];
        this._outputs = [];
        this._nodes = [];
        // Guard each element so null/undefined list entries don't throw a TypeError.
        if (Array.isArray(obj) && obj.every((item) => item && item.__class__)) {
            for (const item of obj) {
                this._nodes.push(new pickle.Node(item));
            }
        }
        else if (obj instanceof Map) {
            for (const [key, value] of obj) {
                this._nodes.push(new pickle.Node(value, key));
            }
        }
        else if (obj && (obj.__class__ || Object(obj) === obj)) {
            this._nodes.push(new pickle.Node(obj));
        }
    }

    /** @returns {Array} always empty. */
    get inputs() {
        return this._inputs;
    }

    /** @returns {Array} always empty. */
    get outputs() {
        return this._outputs;
    }

    /** @returns {Array} nodes built in the constructor. */
    get nodes() {
        return this._nodes;
    }
};
|
|
|
|
pickle.Node = class {

    /**
     * Describes one unpickled value as a node.
     *
     * - Arrays become a 'List' node with a single 'value' attribute holding the array.
     * - Other objects are typed as `__class__.__module__ + '.' + __class__.__name__`
     *   (or 'Object' when they have no `__class__`) and get one attribute per own
     *   enumerable key.
     * - `null`, `undefined` and primitives become an 'Object' node; non-null
     *   primitives carry their value in a single 'value' attribute.
     *
     * `inputs` and `outputs` are always empty.
     *
     * @param {*} obj - value to describe.
     * @param {string} [name] - node name; falsy values become ''.
     */
    constructor(obj, name) {
        this._name = name || '';
        this._inputs = [];
        this._outputs = [];
        this._attributes = [];
        if (Array.isArray(obj)) {
            this._type = { name: 'List' };
            this._attributes.push(new pickle.Attribute('value', obj));
        }
        else if (Object(obj) === obj) {
            const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : 'Object';
            this._type = { name: type };
            for (const key of Object.keys(obj)) {
                this._attributes.push(new pickle.Attribute(key, obj[key]));
            }
        }
        else {
            // Reading `__class__` on null/undefined would throw, and Object.keys() on a
            // string yields one key per character; expose the raw value instead.
            this._type = { name: 'Object' };
            if (obj !== null && obj !== undefined) {
                this._attributes.push(new pickle.Attribute('value', obj));
            }
        }
    }

    /** @returns {{name: string}} the node's type descriptor. */
    get type() {
        return this._type;
    }

    /** @returns {string} the node name ('' when none was given). */
    get name() {
        return this._name;
    }

    /** @returns {Array} always empty. */
    get inputs() {
        return this._inputs;
    }

    /** @returns {Array} always empty. */
    get outputs() {
        return this._outputs;
    }

    /** @returns {Array} attributes built in the constructor. */
    get attributes() {
        return this._attributes;
    }
};
|
|
|
|
pickle.Attribute = class {

    /**
     * A named value attached to a node.
     *
     * @param {string} name - attribute name.
     * @param {*} value - attribute value; when it is truthy and has a truthy
     *   `__class__`, `type` is set to `__module__ + '.' + __name__` of that class.
     */
    constructor(name, value) {
        this._name = name;
        this._value = value;
        const cls = value && value.__class__;
        if (cls) {
            // `_type` is only assigned for class-bearing values; otherwise it stays absent.
            this._type = cls.__module__ + '.' + cls.__name__;
        }
    }

    /** @returns {string} the attribute name. */
    get name() {
        return this._name;
    }

    /** @returns {*} the raw attribute value. */
    get value() {
        return this._value;
    }

    /** @returns {string|undefined} the value's class name, if any. */
    get type() {
        return this._type;
    }
};
|
|
|
|
/**
 * Error type reported by this loader; `name` carries a human-readable
 * headline and `message` carries the detail passed by the caller.
 */
pickle.Error = class extends Error {

    /** @param {string} message - detail text for the failure. */
    constructor(message) {
        super(message);
        this.name = 'Error loading Pickle model.';
    }
};
|
|
|
|
// CommonJS export hook: skipped when no `module.exports` object exists.
// Only the factory is exported; the other pickle.* classes are internal.
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    const factory = pickle.ModelFactory;
    module.exports.ModelFactory = factory;
}